diff --git a/CHANGELOG.md b/CHANGELOG.md
index 28f231e9fd..d24194ed8f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- GaussianProcessRegressor
- OrthogonalMatchingPursuit
- Lars
+ - Ridge
- `AsyncExitStackTestCase` which instantiates and enters async and non-async
`contextlib` exit stacks. Provides temporary file creation.
### Changed
diff --git a/docs/plugins/dffml_model.rst b/docs/plugins/dffml_model.rst
index f69346403d..67874b549f 100644
--- a/docs/plugins/dffml_model.rst
+++ b/docs/plugins/dffml_model.rst
@@ -445,6 +445,8 @@ Predicting with trained model:
| | OrthogonalMatchingPursuit | scikitomp | `scikitomp `_ |
| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| | Lars | scikitlars | `scikitlars `_ |
+| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
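+| | Ridge | scikitridge | `scikitridge <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html#sklearn.linear_model.Ridge>`_ |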
+----------------+-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Classification | KNeighborsClassifier | scikitknn | `scikitknn `_ |
| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/model/scikit/dffml_model_scikit/__init__.py b/model/scikit/dffml_model_scikit/__init__.py
index aa4e0eff65..5ed1ed0ba9 100644
--- a/model/scikit/dffml_model_scikit/__init__.py
+++ b/model/scikit/dffml_model_scikit/__init__.py
@@ -70,6 +70,8 @@
| | OrthogonalMatchingPursuit | scikitomp | `scikitomp `_ |
| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| | Lars | scikitlars | `scikitlars `_ |
+| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
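+| | Ridge | scikitridge | `scikitridge <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Ridge.html#sklearn.linear_model.Ridge>`_ |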
+----------------+-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
| Classification | KNeighborsClassifier | scikitknn | `scikitknn `_ |
| +-------------------------------+----------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/model/scikit/dffml_model_scikit/scikit_models.py b/model/scikit/dffml_model_scikit/scikit_models.py
index 97cca33384..a0fb695ef2 100644
--- a/model/scikit/dffml_model_scikit/scikit_models.py
+++ b/model/scikit/dffml_model_scikit/scikit_models.py
@@ -40,6 +40,7 @@
RANSACRegressor,
OrthogonalMatchingPursuit,
Lars,
+ Ridge,
)
from dffml.util.cli.arg import Arg
@@ -165,6 +166,7 @@ class NoDefaultValue:
OrthogonalMatchingPursuit,
applicable_features,
),
+ ("scikitridge", "Ridge", Ridge, applicable_features),
("scikitlars", "Lars", Lars, applicable_features),
]:
diff --git a/model/scikit/setup.py b/model/scikit/setup.py
index 92c9ef4849..9a99684ac3 100644
--- a/model/scikit/setup.py
+++ b/model/scikit/setup.py
@@ -104,6 +104,7 @@
f"scikitgpr = {IMPORT_NAME}.scikit_models:GaussianProcessRegressorModel",
f"scikitomp = {IMPORT_NAME}.scikit_models:OrthogonalMatchingPursuitModel",
f"scikitlars = {IMPORT_NAME}.scikit_models:LarsModel",
+ f"scikitridge = {IMPORT_NAME}.scikit_models:RidgeModel",
]
},
)
diff --git a/model/scikit/tests/test_scikit.py b/model/scikit/tests/test_scikit.py
index 5c577f9092..d02c769965 100644
--- a/model/scikit/tests/test_scikit.py
+++ b/model/scikit/tests/test_scikit.py
@@ -103,9 +103,9 @@ async def test_02_predict(self):
elif self.MODEL_TYPE is "REGRESSION":
correct = FEATURE_DATA_REGRESSION[int(repo.src_url)][3]
self.assertGreater(
- prediction, correct - (correct * 0.20)
+ prediction, correct - (correct * 0.40)
)
- self.assertLess(prediction, correct + (correct * 0.20))
+ self.assertLess(prediction, correct + (correct * 0.40))
FEATURE_DATA_CLASSIFICATION = [
@@ -208,6 +208,7 @@ async def test_02_predict(self):
"GaussianProcessRegressor",
"OrthogonalMatchingPursuit",
"Lars",
+ "Ridge",
]
for clf in CLASSIFIERS:
diff --git a/tests/integration/test_models.py b/tests/integration/test_models.py
index bbabbcfa6f..310b649d12 100644
--- a/tests/integration/test_models.py
+++ b/tests/integration/test_models.py
@@ -121,3 +121,99 @@ async def test_run(self):
self.assertIn("value", results)
results = results["value"]
self.assertEqual(4, results)
+
+
+class TestScikitRegression(IntegrationCLITestCase):
+    """
+    Integration test driving the ``scikitridge`` model through the CLI:
+    train, assess accuracy, then predict and check the JSON on stdout.
+    """
+
+    # Column names shared by every CSV file; ``medv`` is the predicted column
+    CSV_HEADER = "crim,zn,indus,chas,nox,rm,age,dis,rad,tax,ptratio,b,lstat,medv"
+    # (name, data type) of each feature passed via -model-features
+    FEATURES = [
+        ("crim", "float"),
+        ("zn", "float"),
+        ("indus", "float"),
+        ("chas", "int"),
+        ("nox", "float"),
+        ("rm", "float"),
+        ("age", "int"),
+        ("dis", "float"),
+        ("rad", "int"),
+        ("tax", "float"),
+        ("ptratio", "float"),
+        ("b", "float"),
+        ("lstat", "float"),
+    ]
+
+    def _write_csv(self, *rows):
+        """Write header and rows to a new temp CSV; return its filename."""
+        filename = self.mktempfile() + ".csv"
+        lines = [self.CSV_HEADER, *rows]
+        pathlib.Path(filename).write_text("\n".join(lines) + "\n")
+        return filename
+
+    def _model_args(self, label, filename):
+        """
+        Return the CLI arguments shared by train, accuracy, and predict.
+        ``label`` names the CSV source which reads ``filename``.
+        """
+        return [
+            "-model",
+            "scikitridge",
+            "-model-features",
+            *(f"def:{name}:{dtype}:1" for name, dtype in self.FEATURES),
+            "-model-predict",
+            "medv",
+            "-sources",
+            f"{label}=csv",
+            "-source-filename",
+            filename,
+        ]
+
+    async def test_run(self):
+        self.required_plugins("dffml-model-scikit")
+        # Create the training data
+        train_filename = self._write_csv(
+            "0.00632,18,2.31,0,0.538,6.575,65.2,4.09,1,296,15.3,396.9,4.98,24",
+            "0.02731,0,7.07,0,0.469,6.421,78.9,4.9671,2,242,17.8,396.9,9.14,21.6",
+            "0.02729,0,7.07,0,0.469,7.185,61.1,4.9671,2,242,17.8,392.83,4.03,34.7",
+            "0.03237,0,2.18,0,0.458,6.998,45.8,6.0622,3,222,18.7,394.63,2.94,33.4",
+            "0.06905,0,2.18,0,0.458,7.147,54.2,6.0622,3,222,18.7,396.9,5.33,36.2",
+        )
+        # Create the test data
+        test_filename = self._write_csv(
+            "0.02985,0,2.18,0,0.458,6.43,58.7,6.0622,3,222,18.7,394.12,5.21,28.7",
+            "0.08829,12.5,7.87,0,0.524,6.012,66.6,5.5605,5,311,15.2,395.6,12.43,22.9",
+        )
+        # Create the prediction data
+        predict_filename = self._write_csv(
+            "0.14455,12.5,7.87,0,0.524,6.172,96.1,5.9505,5,311,15.2,396.9,19.15,27.1"
+        )
+        # Train the model
+        await CLI.cli(
+            "train", *self._model_args("training_data", train_filename)
+        )
+        # Assess accuracy
+        await CLI.cli(
+            "accuracy", *self._model_args("test_data", test_filename)
+        )
+        # Ensure JSON output works as expected (#261)
+        with contextlib.redirect_stdout(self.stdout):
+            # Make prediction
+            await CLI._main(
+                "predict",
+                "all",
+                *self._model_args("predict_data", predict_filename),
+            )
+        results = json.loads(self.stdout.getvalue())
+        self.assertIsInstance(results, list)
+        self.assertTrue(results)
+        results = results[0]
+        self.assertIn("prediction", results)
+        results = results["prediction"]
+        self.assertIn("value", results)
+        results = results["value"]
+        self.assertIsNotNone(results)