diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml
index 76c8aac3ed..27b7fda5c4 100644
--- a/.github/workflows/pytest.yml
+++ b/.github/workflows/pytest.yml
@@ -46,7 +46,7 @@ jobs:
       fail-fast: false
       matrix:
         os: [windows-latest, macos-latest, ubuntu-latest]
-        python-version: ['3.7', '3.8', '3.9']  # 3.10 once updated
+        python-version: ['3.8', '3.9', '3.10']
         kind: ['conda', 'source', 'dist']
 
         exclude:
@@ -60,15 +60,15 @@ jobs:
           - os: macos-latest
 
         include:
-          # Add the tag code-cov to ubuntu-3.7-source
+          # Add the tag code-cov to ubuntu-3.8-source
           - os: ubuntu-latest
-            python-version: 3.7
+            python-version: 3.8
            kind: 'source'
            code-cov: true
 
-          # Include one config with dist, ubuntu-3.7-dist
+          # Include one config with dist, ubuntu-3.8-dist
           - os: ubuntu-latest
-            python-version: 3.7
+            python-version: 3.8
            kind: 'dist'
 
    steps:
diff --git a/autosklearn/automl.py b/autosklearn/automl.py
index 93fde84330..fc7f94ee6f 100644
--- a/autosklearn/automl.py
+++ b/autosklearn/automl.py
@@ -42,11 +42,8 @@
 from sklearn.dummy import DummyClassifier, DummyRegressor
 from sklearn.ensemble import VotingClassifier, VotingRegressor
 from sklearn.metrics._classification import type_of_target
-from sklearn.model_selection._split import (
-    BaseCrossValidator,
-    BaseShuffleSplit,
-    _RepeatedSplits,
-)
+from sklearn.model_selection import BaseCrossValidator, BaseShuffleSplit
+from sklearn.model_selection._split import _RepeatedSplits
 from sklearn.pipeline import Pipeline
 from sklearn.utils import check_random_state
 from sklearn.utils.validation import check_is_fitted
diff --git a/autosklearn/evaluation/__init__.py b/autosklearn/evaluation/__init__.py
index ba17513ae0..7c7e15aa20 100644
--- a/autosklearn/evaluation/__init__.py
+++ b/autosklearn/evaluation/__init__.py
@@ -26,11 +26,8 @@
 import numpy as np
 import pynisher
 from ConfigSpace import Configuration
-from sklearn.model_selection._split import (
-    BaseCrossValidator,
-    BaseShuffleSplit,
-    _RepeatedSplits,
-)
+from sklearn.model_selection import BaseCrossValidator, BaseShuffleSplit
+from sklearn.model_selection._split import _RepeatedSplits
 from smac.runhistory.runhistory import RunInfo, RunValue
 from smac.stats.stats import Stats
 from smac.tae import StatusType, TAEAbortException
diff --git a/autosklearn/evaluation/train_evaluator.py b/autosklearn/evaluation/train_evaluator.py
index eb98fb8a2a..3f60e504bc 100644
--- a/autosklearn/evaluation/train_evaluator.py
+++ b/autosklearn/evaluation/train_evaluator.py
@@ -14,6 +14,7 @@
 from sklearn.base import BaseEstimator
 from sklearn.model_selection import (
     BaseCrossValidator,
+    BaseShuffleSplit,
     KFold,
     PredefinedSplit,
     ShuffleSplit,
@@ -21,7 +22,7 @@
     StratifiedShuffleSplit,
     train_test_split,
 )
-from sklearn.model_selection._split import BaseShuffleSplit, _RepeatedSplits
+from sklearn.model_selection._split import _RepeatedSplits
 from smac.tae import StatusType, TAEAbortException
 
 from autosklearn.automl_common.common.utils.backend import Backend
diff --git a/autosklearn/pipeline/components/base.py b/autosklearn/pipeline/components/base.py
index 7b496842b2..a9d10d8c48 100644
--- a/autosklearn/pipeline/components/base.py
+++ b/autosklearn/pipeline/components/base.py
@@ -382,6 +382,7 @@ def __init__(
         # necessary to do this upon the construction of this object
         # self.set_hyperparameters(self.configuration)
         self.choice = None
+        self._fitted = False
 
     def get_components(cls):
         raise NotImplementedError()
@@ -466,11 +467,13 @@ def get_hyperparameter_search_space(
         raise NotImplementedError()
 
     def fit(self, X, y, **kwargs):
-        # Allows to use check_is_fitted on the choice object
-        self.fitted_ = True
+        self._fitted = True
         if kwargs is None:
             kwargs = {}
         return self.choice.fit(X, y, **kwargs)
 
+    def __sklearn_is_fitted__(self) -> bool:
+        return self._fitted
+
     def predict(self, X):
         return self.choice.predict(X)
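For context on the `__sklearn_is_fitted__` hook introduced above: since scikit-learn 0.24, `check_is_fitted` consults this dunder before falling back to scanning for trailing-underscore attributes, which is what lets the ad-hoc `fitted_` flag become the private `_fitted`. A minimal sketch of the protocol (the `Choice` class here is illustrative, not auto-sklearn's real wrapper):

```python
from sklearn.base import BaseEstimator
from sklearn.exceptions import NotFittedError
from sklearn.utils.validation import check_is_fitted


class Choice(BaseEstimator):
    """Toy stand-in for the choice wrapper, just to show the protocol."""

    def __init__(self):
        self._fitted = False

    def fit(self, X, y):
        self._fitted = True
        return self

    def __sklearn_is_fitted__(self) -> bool:
        # check_is_fitted() calls this instead of looking for attributes
        # ending in "_" such as the old `fitted_` flag.
        return self._fitted


est = Choice()
try:
    check_is_fitted(est)  # not fitted yet -> raises
except NotFittedError:
    pass
check_is_fitted(est.fit([[0.0]], [0]))  # passes silently
```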
diff --git a/autosklearn/pipeline/components/classification/__init__.py b/autosklearn/pipeline/components/classification/__init__.py
index 31fa2ea9ca..84434d4b1d 100644
--- a/autosklearn/pipeline/components/classification/__init__.py
+++ b/autosklearn/pipeline/components/classification/__init__.py
@@ -167,7 +167,7 @@ def get_current_iter(self):
 
     def iterative_fit(self, X, y, n_iter=1, **fit_params):
         # Allows to use check_is_fitted on the choice object
-        self.fitted_ = True
+        self._fitted = True
         if fit_params is None:
             fit_params = {}
         return self.choice.iterative_fit(X, y, n_iter=n_iter, **fit_params)
diff --git a/autosklearn/pipeline/components/classification/decision_tree.py b/autosklearn/pipeline/components/classification/decision_tree.py
index 1369ecf906..51e11ad5ea 100644
--- a/autosklearn/pipeline/components/classification/decision_tree.py
+++ b/autosklearn/pipeline/components/classification/decision_tree.py
@@ -114,6 +114,8 @@ def get_hyperparameter_search_space(
     ):
         cs = ConfigurationSpace()
 
+        # The criterion `log_loss` is now also available, but it is equivalent to
+        # `entropy`; we leave it out so as not to confuse the optimizer.
         criterion = CategoricalHyperparameter(
             "criterion", ["gini", "entropy"], default_value="gini"
         )
diff --git a/autosklearn/pipeline/components/classification/extra_trees.py b/autosklearn/pipeline/components/classification/extra_trees.py
index 36edd82584..3829d377b9 100644
--- a/autosklearn/pipeline/components/classification/extra_trees.py
+++ b/autosklearn/pipeline/components/classification/extra_trees.py
@@ -164,6 +164,9 @@ def get_hyperparameter_search_space(
     ):
         cs = ConfigurationSpace()
 
+        # There is also a `criterion` called `log_loss`; however, the documentation
+        # states it is equivalent to `entropy`, so we leave it out so the optimizer
+        # does not need to consider it.
         criterion = CategoricalHyperparameter(
             "criterion", ["gini", "entropy"], default_value="gini"
         )
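The new comments lean on scikit-learn's documentation that the `log_loss` tree criterion added in 1.1 is an alias of `entropy`. A quick sanity check of that equivalence, as a hedged sketch (assumes scikit-learn >= 1.1):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
a = DecisionTreeClassifier(criterion="entropy", random_state=0).fit(X, y)
b = DecisionTreeClassifier(criterion="log_loss", random_state=0).fit(X, y)

# Same impurity measure and same seed should yield identical trees.
assert np.array_equal(a.tree_.feature, b.tree_.feature)
assert np.array_equal(a.tree_.threshold, b.tree_.threshold)
```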
diff --git a/autosklearn/pipeline/components/classification/gradient_boosting.py b/autosklearn/pipeline/components/classification/gradient_boosting.py
index 618028dff7..bc32dc0dab 100644
--- a/autosklearn/pipeline/components/classification/gradient_boosting.py
+++ b/autosklearn/pipeline/components/classification/gradient_boosting.py
@@ -189,7 +189,7 @@ def get_hyperparameter_search_space(
         feat_type: Optional[FEAT_TYPE_TYPE] = None, dataset_properties=None
     ):
         cs = ConfigurationSpace()
-        loss = Constant("loss", "auto")
+        loss = Constant("loss", "log_loss")
         learning_rate = UniformFloatHyperparameter(
             name="learning_rate", lower=0.01, upper=1, default_value=0.1, log=True
         )
diff --git a/autosklearn/pipeline/components/classification/random_forest.py b/autosklearn/pipeline/components/classification/random_forest.py
index 892d8611d5..f3d1edebe1 100644
--- a/autosklearn/pipeline/components/classification/random_forest.py
+++ b/autosklearn/pipeline/components/classification/random_forest.py
@@ -78,7 +78,7 @@ def iterative_fit(self, X, y, sample_weight=None, n_iter=1, refit=False):
             self.min_samples_leaf = int(self.min_samples_leaf)
             self.min_weight_fraction_leaf = float(self.min_weight_fraction_leaf)
 
-            if self.max_features not in ("sqrt", "log2", "auto"):
+            if self.max_features not in ("sqrt", "log2"):
                 max_features = int(X.shape[1] ** float(self.max_features))
             else:
                 max_features = self.max_features
@@ -156,6 +156,9 @@ def get_hyperparameter_search_space(
         feat_type: Optional[FEAT_TYPE_TYPE] = None, dataset_properties=None
     ):
         cs = ConfigurationSpace()
+        # There is also a `criterion` called `log_loss`; however, the documentation
+        # states it is equivalent to `entropy`, so we leave it out so the optimizer
+        # does not need to consider it.
         criterion = CategoricalHyperparameter(
             "criterion", ["gini", "entropy"], default_value="gini"
         )
diff --git a/autosklearn/pipeline/components/classification/sgd.py b/autosklearn/pipeline/components/classification/sgd.py
index 5073f8ec20..a8ff60bfb5 100644
--- a/autosklearn/pipeline/components/classification/sgd.py
+++ b/autosklearn/pipeline/components/classification/sgd.py
@@ -179,8 +179,8 @@ def get_hyperparameter_search_space(
 
         loss = CategoricalHyperparameter(
             "loss",
-            ["hinge", "log", "modified_huber", "squared_hinge", "perceptron"],
-            default_value="log",
+            ["hinge", "log_loss", "modified_huber", "squared_hinge", "perceptron"],
+            default_value="log_loss",
         )
         penalty = CategoricalHyperparameter(
             "penalty", ["l1", "l2", "elasticnet"], default_value="l2"
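The two hunks above track scikit-learn 1.1's loss renames: `HistGradientBoostingClassifier` now spells its loss `log_loss` instead of `auto`, and `SGDClassifier`'s `log` became `log_loss` (`log_log` is not a valid loss name). The renames this patch depends on, collected into a hedged lookup helper (illustrative, not part of auto-sklearn):

```python
# Old -> new estimator option spellings in scikit-learn 1.x; the deprecated
# aliases were later removed, which is what this patch tracks.
RENAMED_OPTIONS = {
    ("SGDClassifier", "loss"): {"log": "log_loss"},
    ("SGDRegressor", "loss"): {"squared_loss": "squared_error"},
    ("HistGradientBoostingClassifier", "loss"): {"auto": "log_loss"},
    ("HistGradientBoostingRegressor", "loss"): {"least_squares": "squared_error"},
    ("DecisionTreeRegressor", "criterion"): {
        "mse": "squared_error",
        "mae": "absolute_error",
    },
}


def translate(estimator: str, param: str, value: str) -> str:
    """Map a pre-1.x option name to its current spelling, if it was renamed."""
    return RENAMED_OPTIONS.get((estimator, param), {}).get(value, value)


assert translate("SGDClassifier", "loss", "log") == "log_loss"
assert translate("SGDClassifier", "loss", "hinge") == "hinge"
```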
diff --git a/autosklearn/pipeline/components/feature_preprocessing/extra_trees_preproc_for_classification.py b/autosklearn/pipeline/components/feature_preprocessing/extra_trees_preproc_for_classification.py
index 904004b201..ff240b0397 100644
--- a/autosklearn/pipeline/components/feature_preprocessing/extra_trees_preproc_for_classification.py
+++ b/autosklearn/pipeline/components/feature_preprocessing/extra_trees_preproc_for_classification.py
@@ -132,6 +132,9 @@ def get_hyperparameter_search_space(
         cs = ConfigurationSpace()
 
         n_estimators = Constant("n_estimators", 100)
+        # There is also a `criterion` called `log_loss`; however, the documentation
+        # states it is equivalent to `entropy`, so we leave it out so the optimizer
+        # does not need to consider it.
         criterion = CategoricalHyperparameter(
             "criterion", ["gini", "entropy"], default_value="gini"
         )
diff --git a/autosklearn/pipeline/components/feature_preprocessing/extra_trees_preproc_for_regression.py b/autosklearn/pipeline/components/feature_preprocessing/extra_trees_preproc_for_regression.py
index 10e741a44e..8131cc4058 100644
--- a/autosklearn/pipeline/components/feature_preprocessing/extra_trees_preproc_for_regression.py
+++ b/autosklearn/pipeline/components/feature_preprocessing/extra_trees_preproc_for_regression.py
@@ -135,7 +135,7 @@ def get_hyperparameter_search_space(
         n_estimators = Constant("n_estimators", 100)
 
         criterion = CategoricalHyperparameter(
-            "criterion", ["mse", "friedman_mse", "mae"]
+            "criterion", ["squared_error", "friedman_mse", "absolute_error"]
         )
         max_features = UniformFloatHyperparameter(
             "max_features", 0.1, 1.0, default_value=1.0
         )
diff --git a/autosklearn/pipeline/components/feature_preprocessing/kernel_pca.py b/autosklearn/pipeline/components/feature_preprocessing/kernel_pca.py
index 08c72efb6f..351ad1a2c3 100644
--- a/autosklearn/pipeline/components/feature_preprocessing/kernel_pca.py
+++ b/autosklearn/pipeline/components/feature_preprocessing/kernel_pca.py
@@ -50,10 +50,10 @@ def fit(self, X, Y=None):
         with warnings.catch_warnings():
             warnings.filterwarnings("error")
             self.preprocessor.fit(X)
-        # Raise an informative error message, equation is based ~line 249 in
-        # kernel_pca.py in scikit-learn
-        if len(self.preprocessor.alphas_ / self.preprocessor.lambdas_) == 0:
+
+        if self.preprocessor._n_features_out == 0:
             raise ValueError("KernelPCA removed all features!")
+
         return self
 
     def transform(self, X):
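On the KernelPCA hunk: scikit-learn 1.0 renamed `lambdas_` and `alphas_` to `eigenvalues_` and `eigenvectors_`, and the private `_n_features_out` used above tracks the number of retained components. A hedged sketch of the same emptiness check through the public attribute (illustrative only; assumes scikit-learn >= 1.0):

```python
import numpy as np
from sklearn.decomposition import KernelPCA

X = np.random.RandomState(0).rand(20, 5)
kpca = KernelPCA(n_components=4, kernel="rbf", remove_zero_eig=True).fit(X)

# Public replacement for the old `lambdas_` attribute:
if kpca.eigenvalues_.shape[0] == 0:
    raise ValueError("KernelPCA removed all features!")
```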
diff --git a/autosklearn/pipeline/components/regression/__init__.py b/autosklearn/pipeline/components/regression/__init__.py
index 9d1ef58650..8ab949fd6f 100644
--- a/autosklearn/pipeline/components/regression/__init__.py
+++ b/autosklearn/pipeline/components/regression/__init__.py
@@ -152,7 +152,7 @@ def get_current_iter(self):
 
     def iterative_fit(self, X, y, n_iter=1, **fit_params):
         # Allows to use check_is_fitted on the choice object
-        self.fitted_ = True
+        self._fitted = True
         if fit_params is None:
             fit_params = {}
         return self.choice.iterative_fit(X, y, n_iter=n_iter, **fit_params)
diff --git a/autosklearn/pipeline/components/regression/ard_regression.py b/autosklearn/pipeline/components/regression/ard_regression.py
index 758c4b04d7..0e4c1058cb 100644
--- a/autosklearn/pipeline/components/regression/ard_regression.py
+++ b/autosklearn/pipeline/components/regression/ard_regression.py
@@ -59,7 +59,6 @@ def fit(self, X, y):
             compute_score=False,
             threshold_lambda=self.threshold_lambda,
             fit_intercept=True,
-            normalize=False,
             copy_X=False,
             verbose=False,
         )
diff --git a/autosklearn/pipeline/components/regression/decision_tree.py b/autosklearn/pipeline/components/regression/decision_tree.py
index 80890889f9..d2a9456bc9 100644
--- a/autosklearn/pipeline/components/regression/decision_tree.py
+++ b/autosklearn/pipeline/components/regression/decision_tree.py
@@ -105,7 +105,7 @@ def get_hyperparameter_search_space(
         cs = ConfigurationSpace()
 
         criterion = CategoricalHyperparameter(
-            "criterion", ["mse", "friedman_mse", "mae"]
+            "criterion", ["squared_error", "friedman_mse", "absolute_error"]
         )
         max_features = Constant("max_features", 1.0)
         max_depth_factor = UniformFloatHyperparameter(
diff --git a/autosklearn/pipeline/components/regression/extra_trees.py b/autosklearn/pipeline/components/regression/extra_trees.py
index b1d8eeb00a..2d77a34e79 100644
--- a/autosklearn/pipeline/components/regression/extra_trees.py
+++ b/autosklearn/pipeline/components/regression/extra_trees.py
@@ -157,7 +157,7 @@ def get_hyperparameter_search_space(
         cs = ConfigurationSpace()
 
         criterion = CategoricalHyperparameter(
-            "criterion", ["mse", "friedman_mse", "mae"]
+            "criterion", ["squared_error", "friedman_mse", "absolute_error"]
         )
         max_features = UniformFloatHyperparameter(
             "max_features", 0.1, 1.0, default_value=1
         )
diff --git a/autosklearn/pipeline/components/regression/gradient_boosting.py b/autosklearn/pipeline/components/regression/gradient_boosting.py
index 16b7df965d..5e4115b0f7 100644
--- a/autosklearn/pipeline/components/regression/gradient_boosting.py
+++ b/autosklearn/pipeline/components/regression/gradient_boosting.py
@@ -174,7 +174,7 @@ def get_hyperparameter_search_space(
     ):
         cs = ConfigurationSpace()
         loss = CategoricalHyperparameter(
-            "loss", ["least_squares"], default_value="least_squares"
+            "loss", ["squared_error"], default_value="squared_error"
         )
         learning_rate = UniformFloatHyperparameter(
             name="learning_rate", lower=0.01, upper=1, default_value=0.1, log=True
         )
diff --git a/autosklearn/pipeline/components/regression/libsvm_svr.py b/autosklearn/pipeline/components/regression/libsvm_svr.py
index c3ac98b1f9..a01a3f9352 100644
--- a/autosklearn/pipeline/components/regression/libsvm_svr.py
+++ b/autosklearn/pipeline/components/regression/libsvm_svr.py
@@ -127,6 +127,9 @@ def predict(self, X):
             raise NotImplementedError
         y_pred = self.estimator.predict(X)
 
+        if y_pred.ndim == 1:
+            y_pred = y_pred.reshape((-1, 1))
+
         inverse = self.scaler.inverse_transform(y_pred)
 
         # Flatten: [[0], [0], [0]] -> [0, 0, 0]
diff --git a/autosklearn/pipeline/components/regression/mlp.py b/autosklearn/pipeline/components/regression/mlp.py
index 42ceff4556..574c944880 100644
--- a/autosklearn/pipeline/components/regression/mlp.py
+++ b/autosklearn/pipeline/components/regression/mlp.py
@@ -204,6 +204,9 @@ def predict(self, X):
 
         y_pred = self.estimator.predict(X)
 
+        if y_pred.ndim == 1:
+            y_pred = y_pred.reshape((-1, 1))
+
         inverse = self.scaler.inverse_transform(y_pred)
 
         # Flatten: [[0], [0], [0]] -> [0, 0, 0]
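The `reshape` guards added to `libsvm_svr.py` and `mlp.py` (and to `sgd.py` below) exist because scikit-learn 1.x validates scaler inputs: `inverse_transform` expects a 2-D `(n_samples, n_features)` array, while a regressor's `predict` returns a 1-D vector for single-output targets. A standalone illustration (names are illustrative):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Scaler fitted on column-shaped training targets.
scaler = StandardScaler().fit(np.arange(10.0).reshape(-1, 1))

y_pred = np.array([0.1, -0.2, 0.3])  # 1-D output of regressor.predict(X)

# Reshape to (n_samples, 1) before inverse-scaling, then flatten back.
inverse = scaler.inverse_transform(y_pred.reshape(-1, 1)).ravel()
print(inverse.shape)  # (3,)
```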
diff --git a/autosklearn/pipeline/components/regression/random_forest.py b/autosklearn/pipeline/components/regression/random_forest.py
index 043d62e16b..5cf5ab1f38 100644
--- a/autosklearn/pipeline/components/regression/random_forest.py
+++ b/autosklearn/pipeline/components/regression/random_forest.py
@@ -143,7 +143,7 @@ def get_hyperparameter_search_space(
     ):
         cs = ConfigurationSpace()
         criterion = CategoricalHyperparameter(
-            "criterion", ["mse", "friedman_mse", "mae"]
+            "criterion", ["squared_error", "friedman_mse", "absolute_error"]
         )
 
         # In contrast to the random forest classifier we want to use more max_features
diff --git a/autosklearn/pipeline/components/regression/sgd.py b/autosklearn/pipeline/components/regression/sgd.py
index 915e45169f..7de1ecbcc4 100644
--- a/autosklearn/pipeline/components/regression/sgd.py
+++ b/autosklearn/pipeline/components/regression/sgd.py
@@ -168,8 +168,19 @@ def configuration_fully_fitted(self):
     def predict(self, X):
         if self.estimator is None:
             raise NotImplementedError()
-        Y_pred = self.estimator.predict(X)
-        return self.scaler.inverse_transform(Y_pred)
+
+        y_pred = self.estimator.predict(X)
+
+        if y_pred.ndim == 1:
+            y_pred = y_pred.reshape((-1, 1))
+
+        inverse = self.scaler.inverse_transform(y_pred)
+
+        # Flatten: [[0], [0], [0]] -> [0, 0, 0]
+        if inverse.ndim == 2 and inverse.shape[1] == 1:
+            inverse = inverse.flatten()
+
+        return inverse
 
     @staticmethod
     def get_properties(dataset_properties=None):
@@ -196,12 +207,12 @@ def get_hyperparameter_search_space(
         loss = CategoricalHyperparameter(
             "loss",
             [
-                "squared_loss",
+                "squared_error",
                 "huber",
                 "epsilon_insensitive",
                 "squared_epsilon_insensitive",
             ],
-            default_value="squared_loss",
+            default_value="squared_error",
         )
         penalty = CategoricalHyperparameter(
             "penalty", ["l1", "l2", "elasticnet"], default_value="l2"
diff --git a/autosklearn/util/__init__.py b/autosklearn/util/__init__.py
index 9f2d05ccd5..44aa08ded1 100644
--- a/autosklearn/util/__init__.py
+++ b/autosklearn/util/__init__.py
@@ -1,7 +1,7 @@
 # -*- encoding: utf-8 -*-
 import re
 
-SUBPATTERN = r"((?P<operation%d>==|>=|>|<)(?P<version%d>(\d+)?(\.[a-zA-Z0-9]+)?(\.[a-zA-Z0-9]+)?))"  # noqa: E501
+SUBPATTERN = r"((?P<operation%d>==|>=|>|<|<=)(?P<version%d>(\d+)?(\.[a-zA-Z0-9]+)?(\.[a-zA-Z0-9]+)?))"  # noqa: E501
 RE_PATTERN = re.compile(
     r"^(?P<name>[\w\-]+)%s?(,%s)?$" % (SUBPATTERN % (1, 1), SUBPATTERN % (2, 2))
 )
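The `autosklearn.util` hunk widens the requirement parser to also accept `<=`. A quick standalone check of the widened pattern (the snippet inlines the same regex, so the group names match the module's):

```python
import re

SUBPATTERN = r"((?P<operation%d>==|>=|>|<|<=)(?P<version%d>(\d+)?(\.[a-zA-Z0-9]+)?(\.[a-zA-Z0-9]+)?))"  # noqa: E501
RE_PATTERN = re.compile(
    r"^(?P<name>[\w\-]+)%s?(,%s)?$" % (SUBPATTERN % (1, 1), SUBPATTERN % (2, 2))
)

match = RE_PATTERN.match("scikit-learn<=1.1.3")
assert match is not None
assert match.group("name") == "scikit-learn"
assert match.group("operation1") == "<="
assert match.group("version1") == "1.1.3"
```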
diff --git a/doc/index.rst b/doc/index.rst
index 8b8a2b5bb4..ec9f3a0509 100644
--- a/doc/index.rst
+++ b/doc/index.rst
@@ -25,6 +25,14 @@ the technology behind *auto-sklearn* by reading our paper published at
 `NeurIPS 2015 <https://papers.neurips.cc/paper/5872-efficient-and-robust-automated-machine-learning>`_
 .
 
+
+.. topic:: Python 3.7
+    With the update to `scikit-learn 1.1 <https://scikit-learn.org/stable/whats_new/v1.1.html>`_,
+    we are forced to drop Python 3.7 from version 0.16.0 onwards.
+    This means that Google Colab, which only supports Python 3.7, will no longer be
+    able to run the latest versions of auto-sklearn. Please use version 0.15.0 if
+    Colab support is required.
+
 .. topic:: NEW: Text feature support
 
     Auto-sklearn now supports text features, check our new example:
diff --git a/pyproject.toml b/pyproject.toml
index a696c0fb46..f304edfce9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -3,7 +3,7 @@
 
 [tool.pytest.ini_options]
 testpaths = ["test"]
-minversion = "3.7"
+minversion = "3.8"
 addopts = "--forked"
 
 [tool.coverage.run]
@@ -21,10 +21,10 @@ exclude_lines = [
 ]
 
 [tool.black]
-target-version = ['py37']
+target-version = ['py38']
 
 [tool.isort]
-py_version = "37"
+py_version = "38"
 profile = "black"  # Play nicely with black
 src_paths = ["autosklearn", "test"]
 known_types = ["typing", "abc"]  # We put these in their own section TYPES
@@ -66,7 +66,7 @@ add-ignore = [  # http://www.pydocstyle.org/en/stable/error_codes.html
 ]
 
 [tool.mypy]
-python_version = "3.7"
+python_version = "3.8"
 show_error_codes = true
diff --git a/requirements.txt b/requirements.txt
index d47fb91474..5d821b1291 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,7 +6,7 @@ numpy>=1.9.0
 scipy>=1.7.0
 joblib
 
-scikit-learn>=0.24.0,<0.25.0
+scikit-learn>=1.1.3,<1.2
 
 dask>=2021.12
 distributed>=2012.12
@@ -19,4 +19,4 @@ tqdm
 ConfigSpace>=0.4.21,<0.5
 pynisher>=0.6.3,<0.7
 pyrfr>=0.8.1,<0.9
-smac>=1.2,<1.3
\ No newline at end of file
+smac>=1.2,<1.3
diff --git a/scripts/readme.md b/scripts/readme.md
index 5cf36486e7..cd8242f1e4 100644
--- a/scripts/readme.md
+++ b/scripts/readme.md
@@ -9,22 +9,13 @@ The working directory will be used to save all temporary and final output.
     working_directory=~/auto-sklearn-metadata/001
     mkdir -p $working_directory
 
-The task type defines whether you want update classification or regression
-metadata:
-
-    task_type=classification
-
-or
-
-    task_type=regression
-
 ## 2. Install the OpenML package and create an OpenML account
 
 Read the [OpenML python package manual](https://openml.github.io/openml-python) for this.
 
 ## 3. Create configuration commands
 
-    python3 01_create_commands.py --working-directory $working_directory --task-type $task_type
+    python3 01_create_commands.py --working-directory $working_directory
 
 This will create a file with all commands necessary to run auto-sklearn on a
 large number of datasets from OpenML. You can change the task IDs or the way
diff --git a/setup.py b/setup.py
index 6e37e0e711..86361fa148 100644
--- a/setup.py
+++ b/setup.py
@@ -14,10 +14,10 @@
         % sys.platform
     )
 
-if sys.version_info < (3, 7):
+if sys.version_info < (3, 8):
     raise ValueError(
         "Unsupported Python version %d.%d.%d found. Auto-sklearn requires Python "
-        "3.7 or higher."
+        "3.8 or higher."
         % (sys.version_info.major, sys.version_info.minor, sys.version_info.micro)
     )
 
@@ -91,10 +91,10 @@
         "Operating System :: OS Independent",
         "Topic :: Scientific/Engineering :: Artificial Intelligence",
         "Topic :: Scientific/Engineering :: Information Analysis",
-        "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
+        "Programming Language :: Python :: 3.10",
     ],
-    python_requires=">=3.7",
+    python_requires=">=3.8",
     url="https://automl.github.io/auto-sklearn",
 )
diff --git a/test/test_pipeline/components/regression/test_gradient_boosting.py b/test/test_pipeline/components/regression/test_gradient_boosting.py
index 6412fd0598..19e6b4636a 100644
--- a/test/test_pipeline/components/regression/test_gradient_boosting.py
+++ b/test/test_pipeline/components/regression/test_gradient_boosting.py
@@ -21,5 +21,5 @@ class GradientBoostingComponentTest(BaseRegressionComponentTest):
     res["default_diabetes_sparse"] = None
     res["diabetes_n_call"] = 11
 
-    sk_mod = sklearn.ensemble.GradientBoostingRegressor
+    sk_mod = sklearn.ensemble.HistGradientBoostingRegressor
     module = GradientBoosting
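A closing note on the test swap: the `GradientBoosting` component is backed by `HistGradientBoostingRegressor`, which since scikit-learn 1.0 no longer needs the experimental opt-in import. A hedged sketch (assumes scikit-learn >= 1.0):

```python
# Under scikit-learn 0.24 this estimator was experimental and required
#   from sklearn.experimental import enable_hist_gradient_boosting
# before the import below; since 1.0 that opt-in is a deprecated no-op.
from sklearn.ensemble import HistGradientBoostingRegressor

est = HistGradientBoostingRegressor(loss="squared_error")  # was "least_squares"
est.fit([[0.0], [1.0], [2.0]], [0.0, 1.0, 2.0])
print(est.predict([[1.5]]))
```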