diff --git a/CHANGELOG.md b/CHANGELOG.md index 28f231e9fd..d24194ed8f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - GaussianProcessRegressor - OrthogonalMatchingPursuit - Lars + - Ridge - `AsyncExitStackTestCase` which instantiates and enters async and non-async `contextlib` exit stacks. Provides temporary file creation. ### Changed diff --git a/docs/plugins/dffml_model.rst b/docs/plugins/dffml_model.rst index f69346403d..67874b549f 100644 --- a/docs/plugins/dffml_model.rst +++ b/docs/plugins/dffml_model.rst @@ -445,6 +445,8 @@ Predicting with trained model: | | OrthogonalMatchingPursuit | scikitomp | `scikitomp `_ | | +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | | Lars | scikitlars | `scikitlars `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | Ridge | scikitridge | `scikitridge `_ | +----------------+-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | Classification | KNeighborsClassifier | scikitknn | `scikitknn `_ | | +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/model/scikit/dffml_model_scikit/__init__.py b/model/scikit/dffml_model_scikit/__init__.py index aa4e0eff65..5ed1ed0ba9 100644 --- a/model/scikit/dffml_model_scikit/__init__.py +++ b/model/scikit/dffml_model_scikit/__init__.py @@ -70,6 +70,8 @@ | | OrthogonalMatchingPursuit | scikitomp | `scikitomp `_ | | +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | | Lars | scikitlars | `scikitlars `_ | +| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| | Ridge | scikitridge | `scikitridge `_ | +----------------+-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ | Classification | KNeighborsClassifier | scikitknn | `scikitknn `_ | | +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ diff --git a/model/scikit/dffml_model_scikit/scikit_models.py b/model/scikit/dffml_model_scikit/scikit_models.py index 97cca33384..a0fb695ef2 100644 --- a/model/scikit/dffml_model_scikit/scikit_models.py +++ b/model/scikit/dffml_model_scikit/scikit_models.py @@ -40,6 +40,7 @@ RANSACRegressor, OrthogonalMatchingPursuit, Lars, + Ridge, ) from dffml.util.cli.arg import Arg @@ -165,6 +166,7 @@ class NoDefaultValue: OrthogonalMatchingPursuit, applicable_features, ), + ("scikitridge", "Ridge", Ridge, applicable_features), ("scikitlars", "Lars", Lars, applicable_features), ]: diff --git a/model/scikit/setup.py b/model/scikit/setup.py index 92c9ef4849..9a99684ac3 100644 --- a/model/scikit/setup.py +++ b/model/scikit/setup.py @@ -104,6 +104,7 @@ f"scikitgpr = {IMPORT_NAME}.scikit_models:GaussianProcessRegressorModel", f"scikitomp = {IMPORT_NAME}.scikit_models:OrthogonalMatchingPursuitModel", f"scikitlars = {IMPORT_NAME}.scikit_models:LarsModel", + f"scikitridge = {IMPORT_NAME}.scikit_models:RidgeModel", ] }, ) diff --git a/model/scikit/tests/test_scikit.py b/model/scikit/tests/test_scikit.py index 5c577f9092..d02c769965 100644 --- a/model/scikit/tests/test_scikit.py +++ b/model/scikit/tests/test_scikit.py @@ -103,9 +103,9 @@ async def test_02_predict(self): elif self.MODEL_TYPE is "REGRESSION": correct = FEATURE_DATA_REGRESSION[int(repo.src_url)][3] self.assertGreater( - prediction, correct - (correct * 0.20) + prediction, correct - (correct * 0.40) ) - self.assertLess(prediction, correct + (correct * 0.20)) + self.assertLess(prediction, correct + (correct * 0.40)) FEATURE_DATA_CLASSIFICATION = [ @@ -208,6 +208,7 @@ async def test_02_predict(self): "GaussianProcessRegressor", "OrthogonalMatchingPursuit", "Lars", + "Ridge", ] for clf in CLASSIFIERS: diff --git a/tests/integration/test_models.py b/tests/integration/test_models.py index bbabbcfa6f..310b649d12 100644 --- a/tests/integration/test_models.py +++ b/tests/integration/test_models.py @@ -121,3 +121,99 @@ async def test_run(self): self.assertIn("value", results) results = results["value"] self.assertEqual(4, results) + + +class TestScikitRegression(IntegrationCLITestCase): + async def test_run(self): + self.required_plugins("dffml-model-scikit") + # Create the training data + train_filename = self.mktempfile() + ".csv" + pathlib.Path(train_filename).write_text( + inspect.cleandoc( + """ + crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,b,lstat,medv + 0.00632,18,2.31,0,0.538,6.575,65.2,4.09,1,296,15.3,396.9,4.98,24 + 0.02731,0,7.07,0,0.469,6.421,78.9,4.9671,2,242,17.8,396.9,9.14,21.6 + 0.02729,0,7.07,0,0.469,7.185,61.1,4.9671,2,242,17.8,392.83,4.03,34.7 + 0.03237,0,2.18,0,0.458,6.998,45.8,6.0622,3,222,18.7,394.63,2.94,33.4 + 0.06905,0,2.18,0,0.458,7.147,54.2,6.0622,3,222,18.7,396.9,5.33,36.2 + """ + ) + + "\n" + ) + # Create the test data + test_filename = self.mktempfile() + ".csv" + pathlib.Path(test_filename).write_text( + inspect.cleandoc( + """ + crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,b,lstat,medv + 0.02985,0,2.18,0,0.458,6.43,58.7,6.0622,3,222,18.7,394.12,5.21,28.7 + 0.08829,12.5,7.87,0,0.524,6.012,66.6,5.5605,5,311,15.2,395.6,12.43,22.9 + """ + ) + + "\n" + ) + # Create the prediction data + predict_filename = self.mktempfile() + ".csv" + pathlib.Path(predict_filename).write_text( + inspect.cleandoc( + """ + crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,b,lstat,medv + 0.14455,12.5,7.87,0,0.524,6.172,96.1,5.9505,5,311,15.2,396.9,19.15,27.1 + """ + ) + + "\n" + ) + # Features + features = "-model-features def:crim:float:1 def:zn:float:1 def:indus:float:1 def:chas:int:1 def:nox:float:1 def:rm:float:1 def:age:int:1 def:dis:float:1 def:rad:int:1 def:tax:float:1 def:ptratio:float:1 def:b:float:1 def:lstat:float:1".split() + # Train the model + await CLI.cli( + "train", + "-model", + "scikitridge", + *features, + "-model-predict", + "medv", + "-sources", + "training_data=csv", + "-source-filename", + train_filename, + ) + # Assess accuracy + await CLI.cli( + "accuracy", + "-model", + "scikitridge", + *features, + "-model-predict", + "medv", + "-sources", + "test_data=csv", + "-source-filename", + test_filename, + ) + # Ensure JSON output works as expected (#261) + with contextlib.redirect_stdout(self.stdout): + # Make prediction + await CLI._main( + "predict", + "all", + "-model", + "scikitridge", + *features, + "-model-predict", + "medv", + "-sources", + "predict_data=csv", + "-source-filename", + predict_filename, + ) + results = json.loads(self.stdout.getvalue()) + self.assertTrue(isinstance(results, list)) + self.assertTrue(results) + results = results[0] + self.assertIn("prediction", results) + results = results["prediction"] + self.assertIn("value", results) + results = results["value"] + self.assertTrue(results is not None)