Skip to content

[ENH] introduce revised version of ETS #2834

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 23 commits into
base: main
Choose a base branch
from
Open

Conversation

TonyBagnall
Copy link
Contributor

@TonyBagnall TonyBagnall commented May 23, 2025

part of #2833

this replaces the current ETS with a numba version that is refactored. It is tested for equivalence, same tests pass and speed up over statsmodels is significant, cannot sensibly compare to statsforecast because that is tuning only, so obvs does more work, can compare that to autoETS

The first version was for some reason predicting an array rather than a float. Its still not clear to me how we use this with fit and predict, we might need a wrapper for an example. Basically fitting the model is so amazingly quick, refitting for each forecast is I believe worthwhile for interface compliance.

@aeon-actions-bot aeon-actions-bot bot added enhancement New feature, improvement request or other non-bug code enhancement forecasting Forecasting package labels May 23, 2025
@aeon-actions-bot
Copy link
Contributor

Thank you for contributing to aeon

I have added the following labels to this PR based on the title: [ $\color{#FEF1BE}{\textsf{enhancement}}$ ].
I have added the following labels to this PR based on the changes made: [ $\color{#31FCCD}{\textsf{forecasting}}$ ]. Feel free to change these if they do not properly represent the PR.

The Checks tab will show the status of our automated tests. You can click on individual test runs in the tab or "Details" in the panel below to see more information if there is a failure.

If our pre-commit code quality check fails, any trivial fixes will automatically be pushed to your PR unless it is a draft.

Don't hesitate to ask questions on the aeon Slack channel if you have any.

PR CI actions

These checkboxes will add labels to enable/disable CI functionality for this PR. This may not take effect immediately, and a new commit may be required to run the new configuration.

  • Run pre-commit checks for all files
  • Run mypy typecheck tests
  • Run all pytest tests and configurations
  • Run all notebook example tests
  • Run numba-disabled codecov tests
  • Stop automatic pre-commit fixes (always disabled for drafts)
  • Disable numba cache loading
  • Push an empty commit to re-run CI checks

@TonyBagnall TonyBagnall changed the title [ENH] first version for correctness testing [ENH] introduce numba version of ets May 23, 2025
@TonyBagnall
Copy link
Contributor Author

output of the numba identical and runtime timing for main and this version

def numba_compare():
    n =100
    import time
    import aeon.forecasting._ets as ets
    import aeon.forecasting._numba_ets as etsfast
    arr = np.random.rand(n)
    slow_ets = ets.ETSForecaster()
    fast_ets = etsfast.ETSForecaster()
    slow_ets.fit(arr)
    fast_ets.fit(arr)
    a1 = slow_ets.predict()
    a2 = fast_ets.predict()
    print(f" Slow = {a1} fast = {a2}")
    for n in range(100000,10000000,100000):
        arr = np.random.rand(n)

        ets_slow_time = time.time()
        slow_ets.fit(arr)
        ets_slow_time = time.time()-ets_slow_time
        ets_fast_time = time.time()
        fast_ets.fit(arr)
        ets_fast_time = time.time()-ets_fast_time
        ets_slow_pred = time.time()
        a1=slow_ets.predict()
        ets_slow_pred = time.time()-ets_slow_pred
        ets_fast_pred = time.time()
        a2=fast_ets.predict()
        ets_fast_pred = time.time()-ets_fast_pred
        equal = np.isclose(a1,a2, 4)
        print(f"{n},{ets_slow_time}, {ets_fast_time},{ets_slow_pred},,{ets_fast_pred}"
              f",{equal}")  # noqa

@TonyBagnall
Copy link
Contributor Author

TonyBagnall commented May 23, 2025

series length, train time, output the same

<style> </style>
2000000 4.12 0.04 [ True]
2100000 4.16 0.03 [ True]
2200000 4.23 0.03 [ True]
2300000 4.81 0.04 [ True]
2400000 5.11 0.04 [ True]
2500000 5.34 0.04 [ True]
2600000 5.27 0.04 [ True]
2700000 5.47 0.04 [ True]
2800000 5.86 0.05 [ True]
2900000 5.67 0.04 [ True]
3000000 6.43 0.04 [ True]
3100000 6.16 0.04 [ True]
3200000 6.08 0.05 [ True]
3300000 6.98 0.09 [ True]
3400000 8.88 0.07 [ True]
3500000 7.51 0.05 [ True]
3600000 7.28 0.05 [ True]
3700000 7.09 0.06 [ True]
3800000 7.44 0.08 [ True]
3900000 7.86 0.06 [ True]
4000000 8.47 0.08 [ True]
4100000 8.05 0.06 [ True]
4200000 7.95 0.06 [ True]
4300000 8.19 0.06 [ True]
4400000 8.40 0.06 [ True]
4500000 8.75 0.19 [ True]

@TonyBagnall
Copy link
Contributor Author

TonyBagnall commented May 23, 2025

much faster than statsmodels version, as equivalent as we can make it
statsmodels Version: 0.14.3

def statsmodels_compare(setup_func, random_seed, catch_errors):
    """Run both our statsforecast and our implementation and crosschecks."""
    import warnings
    warnings.filterwarnings("ignore")
    random.seed(random_seed)
    (
        y,
        m,
        error,
        trendtype,
        seasontype,
        alpha,
        beta,
        gamma,
        phi,
    ) = setup_func()
    # tsml-eval implementation
    start = time.perf_counter()
    f1 = etsfast.ETSForecaster(
        error,
        trendtype,
        seasontype,
        m,
        alpha,
        beta,
        gamma,
        phi,
        1,
    )
    for n in range(10000,1000000,10000):
        y = np.random.rand(n)
        start = time.perf_counter()
        f1 = etsfast.ETSForecaster(
            error,
            trendtype,
            seasontype,
            m,
            alpha,
            beta,
            gamma,
            phi,
            1,
        )
        f1.fit(y)
        aeon_time = time.perf_counter()-start
        from statsmodels.tsa.holtwinters import ExponentialSmoothing
        start = time.perf_counter()
        ets_model = ExponentialSmoothing(
            y[m:],
            trend="add" if trendtype == 1 else "mul" if trendtype == 2 else None,
            damped_trend=(phi != 1 and trendtype != 0),
            seasonal="add" if seasontype == 1 else "mul" if seasontype == 2 else None,
            seasonal_periods=m if seasontype != 0 else None,
            initialization_method="known",
            initial_level=f1.level_,
            initial_trend=f1.trend_ if trendtype != 0 else None,
            initial_seasonal=f1.seasonality_ if seasontype != 0 else None,
        )
        results = ets_model.fit(
            smoothing_level=alpha,
            smoothing_trend=(
                beta / alpha if trendtype != 0 else None
            ),  # statsmodels uses beta*=beta/alpha
            smoothing_seasonal=gamma if seasontype != 0 else None,
            damping_trend=phi if trendtype != 0 else None,
            optimized=False,
        )
        sm_time = time.perf_counter()-start
        print(f"{n},{aeon_time}, {sm_time:0.20f}")  # noqa

@TonyBagnall
Copy link
Contributor Author

length, aeon, statsmodels

<style> </style>
700000 0.011834 3.057566
710000 0.012088 3.139149
720000 0.013966 3.217208
730000 0.012504 3.63404
740000 0.015816 3.24126
750000 0.013041 3.542698
760000 0.013758 3.303388
770000 0.013628 3.537901
780000 0.013103 3.474043
790000 0.013888 3.391304
800000 0.01361 3.626867
810000 0.013806 3.450841
820000 0.013852 3.513971
830000 0.017863 3.660681
840000 0.014946 3.756113
850000 0.015857 3.80876
860000 0.022167 3.932306

@TonyBagnall
Copy link
Contributor Author

TonyBagnall commented May 23, 2025

it cannot be compared to the statsforecast model, because you cannot set alpha, beta, you can only auto tune them

I did though, its much faster because it doesnt tune :)

``python

def statsforecast_compare(setup_func, random_seed, catch_errors):
"""Run both our statsforecast and our implementation and crosschecks."""
import warnings
warnings.filterwarnings("ignore")
from statsforecast import StatsForecast
from statsforecast.models import ETS

random.seed(random_seed)
(
    y,
    m,
    error,
    trendtype,
    seasontype,
    alpha,
    beta,
    gamma,
    phi,
) = setup_func()
# tsml-eval implementation
start = time.perf_counter()
f1 = etsfast.ETSForecaster(
    error,
    trendtype,
    seasontype,
    m,
    alpha,
    beta,
    gamma,
    phi,
    1,
)
# set statsforecast parameters
# Map integer codes to model strings
error_str = {1: 'add', 2: 'mul'}[error]
trend_str = {0: None, 1: 'add', 2: 'mul'}[trendtype]
season_str = {0: None, 1: 'add', 2: 'mul'}[seasontype]
model_string = ''.join([c[0].upper() if c else 'N' for c in
                        [error_str, trend_str, season_str]])  # e.g., "AAN"

# Define StatsForecast object with a specific ETS model
f2 = StatsForecast(
    models=[ETS(model=model_string)],
    freq='M',
    n_jobs=1
)
print("n,aeon,statsforecast")
for n in range(10000,1000000,10000):
    y = np.random.rand(n)
    start = time.perf_counter()
    f1 = etsfast.ETSForecaster(
        error,
        trendtype,
        seasontype,
        m,
        alpha,
        beta,
        gamma,
        phi,
        1,
    )
    f1.fit(y)
    aeon_time = time.perf_counter()-start
  # Forecast the next step (h=1)
    start = time.perf_counter()
    df = pd.DataFrame({
        'unique_id': ['series_1'] * len(y),
        'ds': pd.date_range(start='2023-01-01', periods=len(y), freq='D'),
        'y': y
    })
    # Forecast 1 step ahead
    forecast_df = f2.forecast(df=df, h=1)
    sm_time = time.perf_counter()-start
    print(f"{n},{aeon_time}, {sm_time:0.20f}")  # noqa
        # print(f"Time for ETS: {time_fitets:0.20f}")  # noqa
        # print(f"Time for statsforecast ETS: {time_etscalc}")  # noqa
return True

@TonyBagnall TonyBagnall changed the title [ENH] introduce numba version of ets [ENH] introduce numba version of ETS May 23, 2025
@TonyBagnall TonyBagnall marked this pull request as ready for review May 23, 2025 19:43
Copy link
Member

@MatthewMiddlehurst MatthewMiddlehurst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The current version appears to use numba already? Not quite sure what the code changes are. It does not appear to be just numbafying what is current there at least.

)


@njit(nogil=NOGIL, cache=CACHE)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are these vars? we dont use them anywhere else.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These were in the statsforecast version that this was based on, and I think I was messing with them to try and get it to produce the same output. nogil is set to the default value, so isn't necessary really, the cache one is not though, and apparently it speeds up compilation times by using the previously compiled function if available, so probably good to leave in!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll sere what happens if I change them to our defaults

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

switched to fasthmath in numba for consistency with the rest of the package, it doesn;t seem to make a difference. Those vars were just constant booleans at the top of the file, but I dont think we need them so have gone to just True and False, again to match the package

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is any statsforecast code remaining? If so IMO it would be best to properly attribute it i.e. https://github.com/aeon-toolkit/aeon/pull/2748/files or att a note that this was inspired by that implementation.

@TonyBagnall TonyBagnall requested a review from alexbanwell1 May 24, 2025 13:34
@TonyBagnall
Copy link
Contributor Author

@alexbanwell1 could you look at this PR and address @MatthewMiddlehurst comments?

@TonyBagnall
Copy link
Contributor Author

@MatthewMiddlehurst this is just a reboot from alex's big PR, overrides previous version

@TonyBagnall TonyBagnall changed the title [ENH] introduce numba version of ETS [ENH] introduce revised version of ETS May 26, 2025
@TonyBagnall
Copy link
Contributor Author

TonyBagnall commented May 27, 2025

I'm not sure which variables we should make private (self._foo) and which we should {make set in fit} (self.foo_).

Copy link
Member

@MatthewMiddlehurst MatthewMiddlehurst left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just going to trust the with the fit/transform changes. All experimental anyway. Few more parameter suggestions.

Re: attributes, whatever you want people to be able to access and are willing to document really 🙂. Don't think there are any strict rules

@TonyBagnall
Copy link
Contributor Author

I have changed it so that you can input either strings or ints for error, trend and seasonality, these are now validated in fit and converted to ints for numba efficiency. Added a test of this conversion

Copy link
Member

@hadifawaz1999 hadifawaz1999 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great ! just some minor docs stuff from my side

Either NONE (0), ADDITIVE (1) or MULTIPLICATIVE (2).
seasonality_type : int, default = 0
Either NONE (0), ADDITIVE (1) or MULTIPLICATIVE (2).
error_type : string or in default='additive'
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo "int" not "in"

Type of error model: 'additive' (0) or 'multiplicative' (1)
trend_type : string, int or None, default=None
Type of trend component: None (0), `additive' (1) or 'multiplicative' (2)
seasonality_type : string or None, default=None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

string, int or None no ? not only string or None

trend_type: int = NONE,
seasonality_type: int = NONE,
error_type: Union[int, str] = 1,
trend_type: Union[int, str, None] = 0,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if defaults here are ints it should be the same as in docs, because now one is string and one is int

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature, improvement request or other non-bug code enhancement forecasting Forecasting package
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants