Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Input validation using operations #454

Merged
merged 10 commits into from
Mar 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## [Unreleased]
### Added
- Docstrings and doctestable examples to `record.py`.
- Inputs can be validated using operations.
- `validate` parameter in `Input` takes `Operation.instance_name`.
### Fixed
- New model tutorial mentions file paths that should be edited.

Expand Down
4 changes: 4 additions & 0 deletions dffml/df/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,7 @@ class NotOpImp(Exception):

class InputValidationError(Exception):
pass


class ValidatorMissing(Exception):
pass
115 changes: 103 additions & 12 deletions dffml/df/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
Set,
)

from .exceptions import ContextNotPresent, DefinitionNotInContext
from .exceptions import (
ContextNotPresent,
DefinitionNotInContext,
ValidatorMissing,
)
from .types import Input, Parameter, Definition, Operation, Stage, DataFlow
from .base import (
OperationException,
Expand Down Expand Up @@ -122,6 +126,26 @@ async def inputs(self) -> AsyncIterator[Input]:
for item in self.__inputs:
yield item

def remove_input(self, item: Input):
    """Remove the first stored input whose uid matches ``item``'s uid, if any."""
    match = next(
        (existing for existing in self.__inputs if existing.uid == item.uid),
        None,
    )
    if match is not None:
        self.__inputs.remove(match)

def remove_unvalidated_inputs(self) -> "MemoryInputSet":
    """
    Split out the inputs not yet validated.

    Every input whose ``validated`` attribute is falsy is removed from the
    internal list; the removed inputs are returned wrapped in a new
    MemoryInputSet that shares this set's context.
    """
    pending = [item for item in self.__inputs if not item.validated]
    for item in pending:
        self.__inputs.remove(item)
    return MemoryInputSet(
        MemoryInputSetConfig(ctx=self.ctx, inputs=pending)
    )


class MemoryParameterSetConfig(NamedTuple):
ctx: BaseInputSetContext
Expand Down Expand Up @@ -249,15 +273,19 @@ async def add(self, input_set: BaseInputSet):
handle_string = handle.as_string()
# TODO These ctx.add calls should probably happen after inputs are in
# self.ctxhd

# remove unvalidated inputs
unvalidated_input_set = input_set.remove_unvalidated_inputs()

# If the context for this input set does not exist create a
# NotificationSet for it to notify the orchestrator
if not handle_string in self.input_notification_set:
self.input_notification_set[handle_string] = NotificationSet()
async with self.ctx_notification_set() as ctx:
await ctx.add(input_set.ctx)
await ctx.add((None, input_set.ctx))
# Add the input set to the incoming inputs
async with self.input_notification_set[handle_string]() as ctx:
await ctx.add(input_set)
await ctx.add((unvalidated_input_set, input_set))
# Associate inputs with their context handle grouped by definition
async with self.ctxhd_lock:
# Create dict for handle_string if not present
Expand Down Expand Up @@ -921,6 +949,7 @@ async def run_dispatch(
octx: BaseOrchestratorContext,
operation: Operation,
parameter_set: BaseParameterSet,
set_valid: bool = True,
):
"""
Run an operation in the background and add its outputs to the input
Expand Down Expand Up @@ -952,14 +981,14 @@ async def run_dispatch(
if not key in expand:
output = [output]
for value in output:
inputs.append(
Input(
value=value,
definition=operation.outputs[key],
parents=parents,
origin=(operation.instance_name, key),
)
new_input = Input(
value=value,
definition=operation.outputs[key],
parents=parents,
origin=(operation.instance_name, key),
)
new_input.validated = set_valid
inputs.append(new_input)
except KeyError as error:
raise KeyError(
"Value %s missing from output:definition mapping %s(%s)"
Expand Down Expand Up @@ -1020,6 +1049,38 @@ async def operations_parameter_set_pairs(
):
yield operation, parameter_set

async def validator_target_set_pairs(
    self,
    octx: BaseOperationNetworkContext,
    rctx: BaseRedundancyCheckerContext,
    ctx: BaseInputSetContext,
    dataflow: DataFlow,
    unvalidated_input_set: BaseInputSet,
):
    """
    Yield ``(validator_operation, parameter_set)`` pairs for each input in
    ``unvalidated_input_set``, skipping pairs the redundancy checker has
    already seen.

    Raises
    ------
    ValidatorMissing
        If an input's definition names a validator instance_name that is
        not registered in ``dataflow.validators``.
    """
    async for unvalidated_input in unvalidated_input_set.inputs():
        validator_instance_name = unvalidated_input.definition.validate
        validator = dataflow.validators.get(validator_instance_name, None)
        if validator is None:
            # Bug fix: this message was a plain string, so the
            # {validator_instance_name} placeholder was never interpolated;
            # it must be an f-string.
            raise ValidatorMissing(
                f"Validator with instance_name "
                f"{validator_instance_name} not found"
            )
        # Validator operations declare exactly one input; take its
        # name/definition pair.
        input_name, input_definition = next(iter(validator.inputs.items()))
        parameter = Parameter(
            key=input_name,
            value=unvalidated_input.value,
            origin=unvalidated_input,
            definition=input_definition,
        )
        parameter_set = MemoryParameterSet(
            MemoryParameterSetConfig(ctx=ctx, parameters=[parameter])
        )
        # Renamed loop variable so it no longer shadows parameter_set.
        async for candidate_set, exists in rctx.exists(
            validator, parameter_set
        ):
            if not exists:
                yield validator, candidate_set


@entrypoint("memory")
class MemoryOperationImplementationNetwork(
Expand Down Expand Up @@ -1382,17 +1443,44 @@ async def run_operations_for_ctx(
task.print_stack(file=output)
self.logger.error("%s", output.getvalue().rstrip())
output.close()

elif task is input_set_enters_network:
(
more,
new_input_sets,
) = input_set_enters_network.result()
for new_input_set in new_input_sets:
for (
unvalidated_input_set,
new_input_set,
) in new_input_sets:
async for operation, parameter_set in self.nctx.validator_target_set_pairs(
self.octx,
self.rctx,
ctx,
self.config.dataflow,
unvalidated_input_set,
):
await self.rctx.add(
operation, parameter_set
) # is this required here?
dispatch_operation = await self.nctx.dispatch(
self, operation, parameter_set
)
dispatch_operation.operation = operation
dispatch_operation.parameter_set = (
parameter_set
)
tasks.add(dispatch_operation)
self.logger.debug(
"[%s]: dispatch operation: %s",
ctx_str,
operation.instance_name,
)
# forward inputs to subflow
await self.forward_inputs_to_subflow(
[x async for x in new_input_set.inputs()]
)
# Identify which operations have complete contextually
# Identify which operations have completed contextually
# appropriate input sets which haven't been run yet
async for operation, parameter_set in self.nctx.operations_parameter_set_pairs(
self.ictx,
Expand All @@ -1402,6 +1490,9 @@ async def run_operations_for_ctx(
self.config.dataflow,
new_input_set=new_input_set,
):
# Validation operations shouldn't be run here
if operation.validator:
continue
# Add inputs and operation to redundancy checker before
# dispatch
await self.rctx.add(operation, parameter_set)
Expand Down
18 changes: 14 additions & 4 deletions dffml/df/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ class Operation(NamedTuple, Entrypoint):
conditions: Optional[List[Definition]] = []
expand: Optional[List[str]] = []
instance_name: Optional[str] = None
validator: bool = False

def export(self):
exported = {
Expand Down Expand Up @@ -270,11 +271,13 @@ def __init__(
definition: Definition,
parents: Optional[List["Input"]] = None,
origin: Optional[Union[str, Tuple[Operation, str]]] = "seed",
validated: bool = True,
*,
uid: Optional[str] = "",
):
# TODO Add optional parameter Input.target which specifies the operation
# instance name this Input is intended for.
self.validated = validated
if parents is None:
parents = []
if definition.spec is not None:
Expand All @@ -288,7 +291,11 @@ def __init__(
elif isinstance(value, dict):
value = definition.spec(**value)
if definition.validate is not None:
value = definition.validate(value)
if callable(definition.validate):
value = definition.validate(value)
# if validate is a string (operation.instance_name) set `not validated`
elif isinstance(definition.validate, str):
self.validated = False
self.value = value
self.definition = definition
self.parents = parents
Expand Down Expand Up @@ -424,6 +431,8 @@ def __post_init__(self):
self.by_origin = {}
if self.implementations is None:
self.implementations = {}
self.validators = {} # Maps `validator` ops instance_name to op

# Allow callers to pass in functions decorated with op. Iterate over the
# given operations and replace any which have been decorated with their
# operation. Add the implementation to our dict of implementations.
Expand Down Expand Up @@ -451,9 +460,10 @@ def __post_init__(self):
self.operations[instance_name] = operation
value = operation
# Make sure every operation has the correct instance name
self.operations[instance_name] = value._replace(
instance_name=instance_name
)
value = value._replace(instance_name=instance_name)
self.operations[instance_name] = value
if value.validator:
self.validators[instance_name] = value
# Grab all definitions from operations
operations = list(self.operations.values())
definitions = list(
Expand Down
2 changes: 1 addition & 1 deletion examples/shouldi/tests/test_npm_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ class TestRunNPM_AuditOp(AsyncTestCase):
"36b3ce51780ee6ea8dcec266c9d09e3a00198868ba1b041569950b82cf45884da0c47ec354dd8514022169849dfe8b7c",
)
async def test_run(self, npm_audit, javascript_algo):
with prepend_to_path(npm_audit / "bin",):
with prepend_to_path(npm_audit / "bin"):
results = await run_npm_audit(
str(
javascript_algo
Expand Down
2 changes: 1 addition & 1 deletion model/scikit/dffml_model_scikit/scikit_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def applicable_features(self, features):
field(
"Directory where state should be saved",
default=pathlib.Path(
"~", ".cache", "dffml", f"scikit-{entry_point_name}",
"~", ".cache", "dffml", f"scikit-{entry_point_name}"
),
),
),
Expand Down
2 changes: 1 addition & 1 deletion scripts/docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def gen_docs(

def fake_getpwuid(uid):
    """Return a synthetic passwd entry for *uid*: user "user", home /home/user."""
    entry = ("user", "x", uid, uid, "", "/home/user", "/bin/bash")
    return pwd.struct_passwd(entry)


Expand Down
49 changes: 49 additions & 0 deletions tests/test_types.py → tests/test_input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ def pie_validation(x):
# ``validate`` as a callable: applied inline when an Input is constructed.
ShapeName = Definition(
    name="shape_name", primitive="str", validate=lambda x: x.upper()
)
# ``validate`` as a string naming a validator operation's instance_name;
# Inputs with this definition start out unvalidated and are routed through
# that operation by the orchestrator before downstream ops see them.
SHOUTIN = Definition(
    name="shout_in", primitive="str", validate="validate_shout_instance"
)
SHOUTOUT = Definition(name="shout_out", primitive="str")


@op(
Expand All @@ -35,6 +39,20 @@ async def get_circle(name: str, radius: float, pie: float):
}


@op(
    inputs={"shout_in": SHOUTIN},
    outputs={"shout_in_validated": SHOUTIN},
    validator=True,
)
def validate_shouts(shout_in):
    """Validator operation: tag the incoming shout with a ``_validated`` suffix."""
    validated_value = shout_in + "_validated"
    return {"shout_in_validated": validated_value}


@op(inputs={"shout_in": SHOUTIN}, outputs={"shout_out": SHOUTOUT})
def echo_shout(shout_in):
    """Pass the (already validated) shout through unchanged as ``shout_out``."""
    result = {"shout_out": shout_in}
    return result


class TestDefintion(AsyncTestCase):
async def setUp(self):
self.dataflow = DataFlow(
Expand Down Expand Up @@ -80,3 +98,34 @@ async def test_validation_error(self):
]
}
pass

async def test_vaildation_by_op(self):
    """
    An Input whose definition names a validator operation is routed through
    that operation before downstream ops run; echo_shout therefore sees the
    validated value.
    """
    # NOTE(review): "vaildation" in the method name is a typo; kept as-is so
    # the test's discovered name does not change.
    dataflow_under_test = DataFlow(
        operations={
            "validate_shout_instance": validate_shouts.op,
            "echo_shout": echo_shout.op,
            "get_single": GetSingle.imp.op,
        },
        seed=[
            Input(
                value=[echo_shout.op.outputs["shout_out"].name],
                definition=GetSingle.op.inputs["spec"],
            )
        ],
        implementations={
            validate_shouts.op.name: validate_shouts.imp,
            echo_shout.op.name: echo_shout.imp,
        },
    )
    inputs_by_ctx = {
        "TestShoutOut": [
            Input(value="validation_status:", definition=SHOUTIN)
        ]
    }
    async with MemoryOrchestrator.withconfig({}) as orchestrator:
        async with orchestrator(dataflow_under_test) as octx:
            async for _ctx_str, results in octx.run(inputs_by_ctx):
                self.assertIn("shout_out", results)
                self.assertEqual(
                    results["shout_out"], "validation_status:_validated"
                )