Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 1dc6181

Browse files
authored
df: types: Add validate to definition
* Auto validate Inputs via their validate parameter in their Definition Fixes: #349
1 parent 93e1e9a commit 1dc6181

File tree

14 files changed

+109
-34
lines changed

14 files changed

+109
-34
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2828
- MySQL connector
2929
- Documented style for imports.
3030
- Documented use of numpy docstrings.
31+
- `Inputs` can now be sanitized using function passed in `validate` parameter
3132
- Helper utilities to take callables with numpy style docstrings and
3233
create config classes out of them using `make_config`.
3334
### Changed

dffml/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,7 @@ def _fromdict(cls, **kwargs):
197197
"config": {
198198
key: value
199199
if is_config_dict(value)
200-
else {"arg": value, "config": {},}
200+
else {"arg": value, "config": {}}
201201
for key, value in config.items()
202202
},
203203
}

dffml/db/base.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,7 @@
22
import inspect
33
import functools
44
import collections
5-
from typing import (
6-
Any,
7-
List,
8-
Optional,
9-
Dict,
10-
Tuple,
11-
Union,
12-
AsyncIterator,
13-
)
5+
from typing import Any, List, Optional, Dict, Tuple, Union, AsyncIterator
146

157
from dffml.df.base import BaseDataFlowObject, BaseDataFlowObjectContext
168
from dffml.util.entrypoint import base_entry_point

dffml/db/sqlite.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ class SqliteDatabaseConfig:
1616

1717
class SqliteDatabaseContext(SQLDatabaseContext):
1818
async def create_table(
19-
self, table_name: str, cols: Dict[str, str],
19+
self, table_name: str, cols: Dict[str, str]
2020
) -> None:
2121
query = self.create_table_query(table_name, cols)
2222
self.logger.debug(query)
2323
self.parent.cursor.execute(query)
2424

25-
async def insert(self, table_name: str, data: Dict[str, Any],) -> None:
25+
async def insert(self, table_name: str, data: Dict[str, Any]) -> None:
2626
query, query_values = self.insert_query(table_name, data)
2727
async with self.parent.lock:
2828
with self.parent.db:

dffml/df/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,7 @@ class DefinitionNotInContext(Exception):
1212

1313
class NotOpImp(Exception):
1414
pass
15+
16+
17+
class InputValidationError(Exception):
18+
pass

dffml/df/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ class Definition(NamedTuple):
3737
lock: bool = False
3838
# spec is a NamedTuple which could be populated via a dict
3939
spec: NamedTuple = None
40+
# validate property will be a callable (function or lambda) which returns
41+
# the sanitized version of the value
42+
validate: Callable[[Any], Any] = None
4043

4144
def __repr__(self):
4245
return self.name
@@ -273,6 +276,8 @@ def __init__(
273276
parents = []
274277
if isinstance(value, dict) and definition.spec is not None:
275278
value = definition.spec(**value)
279+
if definition.validate is not None:
280+
value = definition.validate(value)
276281
self.value = value
277282
self.definition = definition
278283
self.parents = parents

feature/git/dffml_feature_git/feature/operations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ async def git_repo_checkout(repo: Dict[str, str], commit: str):
119119
await check_output("git", "checkout", commit, cwd=repo.directory)
120120
return {
121121
"repo": GitRepoCheckedOutSpec(
122-
URL=repo.URL, directory=repo.directory, commit=commit,
122+
URL=repo.URL, directory=repo.directory, commit=commit
123123
)
124124
}
125125

model/scikit/dffml_model_scikit/scikit_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,7 +224,7 @@ async def predict(
224224
predict = np.array(df)
225225
prediction = predictor(predict)
226226
self.logger.debug(
227-
"Predicted cluster for {}: {}".format(predict, prediction,)
227+
"Predicted cluster for {}: {}".format(predict, prediction)
228228
)
229229
repo.predicted(prediction[0], self.confidence)
230230
yield repo

model/scikit/dffml_model_scikit/scikit_models.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ def applicable_features(self, features):
120120
QuadraticDiscriminantAnalysis,
121121
applicable_features,
122122
),
123-
("scikitlr", "LinearRegression", LinearRegression, applicable_features,),
123+
("scikitlr", "LinearRegression", LinearRegression, applicable_features),
124124
(
125125
"scikitlor",
126126
"LogisticRegression",
@@ -139,17 +139,12 @@ def applicable_features(self, features):
139139
ExtraTreesClassifier,
140140
applicable_features,
141141
),
142-
(
143-
"scikitbgc",
144-
"BaggingClassifier",
145-
BaggingClassifier,
146-
applicable_features,
147-
),
142+
("scikitbgc", "BaggingClassifier", BaggingClassifier, applicable_features),
148143
("scikiteln", "ElasticNet", ElasticNet, applicable_features),
149-
("scikitbyr", "BayesianRidge", BayesianRidge, applicable_features,),
150-
("scikitlas", "Lasso", Lasso, applicable_features,),
151-
("scikitard", "ARDRegression", ARDRegression, applicable_features,),
152-
("scikitrsc", "RANSACRegressor", RANSACRegressor, applicable_features,),
144+
("scikitbyr", "BayesianRidge", BayesianRidge, applicable_features),
145+
("scikitlas", "Lasso", Lasso, applicable_features),
146+
("scikitard", "ARDRegression", ARDRegression, applicable_features),
147+
("scikitrsc", "RANSACRegressor", RANSACRegressor, applicable_features),
153148
("scikitbnb", "BernoulliNB", BernoulliNB, applicable_features),
154149
("scikitmnb", "MultinomialNB", MultinomialNB, applicable_features),
155150
(
@@ -176,8 +171,8 @@ def applicable_features(self, features):
176171
OrthogonalMatchingPursuit,
177172
applicable_features,
178173
),
179-
("scikitridge", "Ridge", Ridge, applicable_features,),
180-
("scikitlars", "Lars", Lars, applicable_features,),
174+
("scikitridge", "Ridge", Ridge, applicable_features),
175+
("scikitlars", "Lars", Lars, applicable_features),
181176
("scikitkmeans", "KMeans", KMeans, applicable_features),
182177
("scikitbirch", "Birch", Birch, applicable_features),
183178
(

source/mysql/dffml_source_mysql/db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class MySQLDatabaseContext(SQLDatabaseContext):
2323
BIND_DECLARATION: str = "%s"
2424

2525
async def create_table(
26-
self, table_name: str, cols: Dict[str, str],
26+
self, table_name: str, cols: Dict[str, str]
2727
) -> None:
2828
query = self.create_table_query(table_name, cols)
2929
self.logger.debug(query)

source/mysql/dffml_source_mysql/source.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,7 @@ async def update(self, repo: Repo):
5151
self.logger.debug("update: %s", await self.repo(repo.key))
5252

5353
def convert_to_repo(self, result):
54-
modified_repo = {
55-
"key": "",
56-
"data": {"features": {}, "prediction": {}},
57-
}
54+
modified_repo = {"key": "", "data": {"features": {}, "prediction": {}}}
5855
for key, value in result.items():
5956
if key.startswith("feature_"):
6057
modified_repo["data"]["features"][

tests/integration/test_service_dev.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import inspect
88
import pathlib
99
import unittest.mock
10-
1110
from dffml.df.types import DataFlow
1211
from dffml.cli.cli import CLI
1312
from dffml.service.dev import Develop

tests/test_types.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from dffml.df.base import op
2+
from dffml.df.types import DataFlow, Input, Definition
3+
from dffml.operation.output import GetSingle
4+
from dffml.util.asynctestcase import AsyncTestCase
5+
from dffml.df.memory import MemoryOrchestrator
6+
from dffml.operation.mapping import MAPPING
7+
from dffml.df.exceptions import InputValidationError
8+
9+
10+
def pie_validation(x):
11+
if x == 3.14:
12+
return x
13+
raise InputValidationError()
14+
15+
16+
Pie = Definition(name="pie", primitive="float", validate=pie_validation)
17+
Radius = Definition(name="radius", primitive="float")
18+
Area = Definition(name="area", primitive="float")
19+
ShapeName = Definition(
20+
name="shape_name", primitive="str", validate=lambda x: x.upper()
21+
)
22+
23+
24+
@op(
25+
inputs={"name": ShapeName, "radius": Radius, "pie": Pie},
26+
outputs={"shape": MAPPING},
27+
)
28+
async def get_circle(name: str, radius: float, pie: float):
29+
return {
30+
"shape": {
31+
"name": name,
32+
"radius": radius,
33+
"area": pie * radius * radius,
34+
}
35+
}
36+
37+
38+
class TestDefintion(AsyncTestCase):
39+
async def setUp(self):
40+
self.dataflow = DataFlow(
41+
operations={
42+
"get_circle": get_circle.op,
43+
"get_single": GetSingle.imp.op,
44+
},
45+
seed=[
46+
Input(
47+
value=[get_circle.op.outputs["shape"].name],
48+
definition=GetSingle.op.inputs["spec"],
49+
)
50+
],
51+
implementations={"get_circle": get_circle.imp},
52+
)
53+
54+
async def test_validate(self):
55+
test_inputs = {
56+
"area": [
57+
Input(value="unitcircle", definition=ShapeName),
58+
Input(value=1, definition=Radius),
59+
Input(value=3.14, definition=Pie),
60+
]
61+
}
62+
async with MemoryOrchestrator.withconfig({}) as orchestrator:
63+
async with orchestrator(self.dataflow) as octx:
64+
async for ctx_str, results in octx.run(test_inputs):
65+
self.assertIn("mapping", results)
66+
results = results["mapping"]
67+
self.assertEqual(results["name"], "UNITCIRCLE")
68+
self.assertEqual(results["area"], 3.14)
69+
self.assertEqual(results["radius"], 1)
70+
71+
async def test_validation_error(self):
72+
with self.assertRaises(InputValidationError):
73+
test_inputs = {
74+
"area": [
75+
Input(value="unitcircle", definition=ShapeName),
76+
Input(value=1, definition=Radius),
77+
Input(
78+
value=4, definition=Pie
79+
), # this should raise validation eror
80+
]
81+
}
82+
pass

tests/util/double_ret.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
def double_ret(
5-
super_cool_arg, *, other_very_cool_arg: Optional[Dict[str, Any]] = None,
5+
super_cool_arg, *, other_very_cool_arg: Optional[Dict[str, Any]] = None
66
) -> Tuple[str, Tuple]:
77
"""
88
This is the short description.

0 commit comments

Comments
 (0)