diff --git a/CHANGELOG.md b/CHANGELOG.md index ccd496ca3f..28f231e9fd 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,25 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Added +- scikit models + - Classifiers + - LogisticRegression + - GradientBoostingClassifier + - BernoulliNB + - ExtraTreesClassifier + - BaggingClassifier + - LinearDiscriminantAnalysis + - MultinomialNB + - Regressors + - ElasticNet + - BayesianRidge + - Lasso + - ARDRegression + - RANSACRegressor + - DecisionTreeRegressor + - GaussianProcessRegressor + - OrthogonalMatchingPursuit + - Lars - `AsyncExitStackTestCase` which instantiates and enters async and non-async `contextlib` exit stacks. Provides temporary file creation. ### Changed diff --git a/docs/plugins/dffml_model.rst b/docs/plugins/dffml_model.rst index 5bbc754886..f69346403d 100644 --- a/docs/plugins/dffml_model.rst +++ b/docs/plugins/dffml_model.rst @@ -427,6 +427,24 @@ Predicting with trained model: | Type | Model | Entrypoint | Parameters | +================+===============================+================+===============================================================================================================================================================================================+ | Regression | LinearRegression | scikitlr | `scikitlr `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | ElasticNet | scikiteln | `scikiteln `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | BayesianRidge | scikitbyr | `scikitbyr `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | Lasso | scikitlas | `scikitlas `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | ARDRegression | scikitard | `scikitard `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | RANSACRegressor | scikitrsc | `scikitrsc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | DecisionTreeRegressor | scikitdtr | `scikitdtr `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | GaussianProcessRegressor | scikitgpr | `scikitgpr `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | OrthogonalMatchingPursuit | scikitomp | `scikitomp `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | Lars | scikitlars | `scikitlars `_ | +----------------+-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | Classification | KNeighborsClassifier | scikitknn | `scikitknn `_ | | +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -443,6 +461,22 @@ Predicting with trained model: | | MLPClassifier | scikitmlp | `scikitmlp `_ | | +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | | GaussianNB | scikitgnb | `scikitgnb `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | SVC | scikitsvc | `scikitsvc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | LogisticRegression | scikitlor | `scikitlor `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | GradientBoostingClassifier | scikitgbc | `scikitgbc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | BernoulliNB | scikitbnb | `scikitbnb `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | ExtraTreesClassifier | scikitetc | `scikitetc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | BaggingClassifier | scikitbgc | `scikitbgc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | LinearDiscriminantAnalysis | scikitlda | `scikitlda `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | MultinomialNB | scikitmnb | `scikitmnb `_ | +----------------+-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/model/scikit/dffml_model_scikit/__init__.py b/model/scikit/dffml_model_scikit/__init__.py index 7bea9e8af8..aa4e0eff65 100644 --- a/model/scikit/dffml_model_scikit/__init__.py +++ b/model/scikit/dffml_model_scikit/__init__.py @@ -52,6 +52,24 @@ | Type | Model | Entrypoint | Parameters | +================+===============================+================+===============================================================================================================================================================================================+ | Regression | LinearRegression | scikitlr | `scikitlr `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | ElasticNet | scikiteln | `scikiteln `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | BayesianRidge | scikitbyr | `scikitbyr `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | Lasso | scikitlas | `scikitlas `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | ARDRegression | scikitard | `scikitard `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | RANSACRegressor | scikitrsc | `scikitrsc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | DecisionTreeRegressor | scikitdtr | `scikitdtr `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | GaussianProcessRegressor | scikitgpr | `scikitgpr `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | OrthogonalMatchingPursuit | scikitomp | `scikitomp `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | Lars | scikitlars | `scikitlars `_ | +----------------+-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | Classification | KNeighborsClassifier | scikitknn | `scikitknn `_ | | +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ @@ -68,6 +86,22 @@ | | MLPClassifier | scikitmlp | `scikitmlp `_ | | +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | | GaussianNB | scikitgnb | `scikitgnb `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | SVC | scikitsvc | `scikitsvc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | LogisticRegression | scikitlor | `scikitlor `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | GradientBoostingClassifier | scikitgbc | `scikitgbc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | BernoulliNB | scikitbnb | `scikitbnb `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | ExtraTreesClassifier | scikitetc | `scikitetc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | BaggingClassifier | scikitbgc | `scikitbgc `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | LinearDiscriminantAnalysis | scikitlda | `scikitlda `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | MultinomialNB | scikitmnb | `scikitmnb `_ | +----------------+-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/model/scikit/dffml_model_scikit/scikit_models.py b/model/scikit/dffml_model_scikit/scikit_models.py index 03c6e33382..97cca33384 100644 --- a/model/scikit/dffml_model_scikit/scikit_models.py +++ b/model/scikit/dffml_model_scikit/scikit_models.py @@ -13,12 +13,34 @@ from sklearn.neural_network import MLPClassifier from sklearn.neighbors import KNeighborsClassifier from sklearn.svm import SVC -from sklearn.gaussian_process import GaussianProcessClassifier -from sklearn.tree import DecisionTreeClassifier -from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier -from sklearn.naive_bayes import GaussianNB -from sklearn.discriminant_analysis import QuadraticDiscriminantAnalysis -from sklearn.linear_model import LinearRegression +from sklearn.gaussian_process import ( + GaussianProcessClassifier, + GaussianProcessRegressor, +) +from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor +from sklearn.ensemble import ( + RandomForestClassifier, + AdaBoostClassifier, + GradientBoostingClassifier, + ExtraTreesClassifier, + BaggingClassifier, +) +from sklearn.naive_bayes import GaussianNB, BernoulliNB, MultinomialNB +from sklearn.discriminant_analysis import ( + QuadraticDiscriminantAnalysis, + LinearDiscriminantAnalysis, +) +from sklearn.linear_model import ( + LinearRegression, + LogisticRegression, + ElasticNet, + BayesianRidge, + Lasso, + ARDRegression, + RANSACRegressor, + OrthogonalMatchingPursuit, + Lars, +) from dffml.util.cli.arg import Arg from dffml.util.entrypoint import entry_point @@ -88,6 +110,62 @@ class NoDefaultValue: applicable_features, ), ("scikitlr", "LinearRegression", LinearRegression, applicable_features), + ( + "scikitlor", + "LogisticRegression", + LogisticRegression, + applicable_features, + ), + ( + "scikitgbc", + "GradientBoostingClassifier", + GradientBoostingClassifier, + applicable_features, + ), + ( + "scikitetc", + "ExtraTreesClassifier", + ExtraTreesClassifier, + applicable_features, + ), + ( + "scikitbgc", + "BaggingClassifier", + BaggingClassifier, + applicable_features, + ), + ("scikiteln", "ElasticNet", ElasticNet, applicable_features,), + ("scikitbyr", "BayesianRidge", BayesianRidge, applicable_features,), + ("scikitlas", "Lasso", Lasso, applicable_features,), + ("scikitard", "ARDRegression", ARDRegression, applicable_features), + ("scikitrsc", "RANSACRegressor", RANSACRegressor, applicable_features), + ("scikitbnb", "BernoulliNB", BernoulliNB, applicable_features), + ("scikitmnb", "MultinomialNB", MultinomialNB, applicable_features), + ( + "scikitlda", + "LinearDiscriminantAnalysis", + LinearDiscriminantAnalysis, + applicable_features, + ), + ( + "scikitdtr", + "DecisionTreeRegressor", + DecisionTreeRegressor, + applicable_features, + ), + ( + "scikitgpr", + "GaussianProcessRegressor", + GaussianProcessRegressor, + applicable_features, + ), + ( + "scikitomp", + "OrthogonalMatchingPursuit", + OrthogonalMatchingPursuit, + applicable_features, + ), + ("scikitlars", "Lars", Lars, applicable_features), ]: parameters = inspect.signature(cls).parameters diff --git a/model/scikit/setup.py b/model/scikit/setup.py index ce1d9b1648..92c9ef4849 100644 --- a/model/scikit/setup.py +++ b/model/scikit/setup.py @@ -88,6 +88,22 @@ f"scikitqda = {IMPORT_NAME}.scikit_models:QuadraticDiscriminantAnalysisModel", f"scikitsvc = {IMPORT_NAME}.scikit_models:SVCModel", f"scikitlr = {IMPORT_NAME}.scikit_models:LinearRegressionModel", + f"scikitlor = {IMPORT_NAME}.scikit_models:LogisticRegressionModel", + f"scikitgbc = {IMPORT_NAME}.scikit_models:GradientBoostingClassifierModel", + f"scikitetc = {IMPORT_NAME}.scikit_models:ExtraTreesClassifierModel", + f"scikitbgc = {IMPORT_NAME}.scikit_models:BaggingClassifierModel", + f"scikiteln = {IMPORT_NAME}.scikit_models:ElasticNetModel", + f"scikitbyr = {IMPORT_NAME}.scikit_models:BayesianRidgeModel", + f"scikitlas = {IMPORT_NAME}.scikit_models:LassoModel", + f"scikitard = {IMPORT_NAME}.scikit_models:ARDRegressionModel", + f"scikitrsc = {IMPORT_NAME}.scikit_models:RANSACRegressorModel", + f"scikitbnb = {IMPORT_NAME}.scikit_models:BernoulliNBModel", + f"scikitmnb = {IMPORT_NAME}.scikit_models:MultinomialNBModel", + f"scikitlda = {IMPORT_NAME}.scikit_models:LinearDiscriminantAnalysisModel", + f"scikitdtr = {IMPORT_NAME}.scikit_models:DecisionTreeRegressorModel", + f"scikitgpr = {IMPORT_NAME}.scikit_models:GaussianProcessRegressorModel", + f"scikitomp = {IMPORT_NAME}.scikit_models:OrthogonalMatchingPursuitModel", + f"scikitlars = {IMPORT_NAME}.scikit_models:LarsModel", ] }, ) diff --git a/model/scikit/tests/test_scikit.py b/model/scikit/tests/test_scikit.py index 85bdf8592a..5c577f9092 100644 --- a/model/scikit/tests/test_scikit.py +++ b/model/scikit/tests/test_scikit.py @@ -188,9 +188,27 @@ async def test_02_predict(self): "AdaBoostClassifier", "GaussianNB", "QuadraticDiscriminantAnalysis", + "LogisticRegression", + "GradientBoostingClassifier", + "BernoulliNB", + "ExtraTreesClassifier", + "BaggingClassifier", + "LinearDiscriminantAnalysis", + "MultinomialNB", ] -REGRESSORS = ["LinearRegression"] +REGRESSORS = [ + "LinearRegression", + "ElasticNet", + "BayesianRidge", + "Lasso", + "ARDRegression", + "RANSACRegressor", + "DecisionTreeRegressor", + "GaussianProcessRegressor", + "OrthogonalMatchingPursuit", + "Lars", +] for clf in CLASSIFIERS: test_cls = type(