diff --git a/.travis.yml b/.travis.yml index d8657a8f..cbd8cf82 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,19 +2,26 @@ language: python python: - "2.7" - - "3.4" - "3.5" + - "3.6" + - "3.7" + +addons: + apt: + packages: + - graphviz + install: - pip install Sphinx sphinx_rtd_theme codecov packaging - "python -c $'import os, packaging.version as version\\nv = version.parse(os.environ.get(\"TRAVIS_TAG\", \"1.0\")).public\\nwith open(\"VERSION\", \"w\") as f: f.write(v)'" - - python setup.py install + - pip install -e .[test] - cd docs - make clean html - cd .. script: - - python setup.py nosetests --with-coverage --cover-package=graphkit + - pytest -v --cov=graphkit deploy: provider: pypi diff --git a/README.md b/README.md index 0e1e95a4..af414020 100644 --- a/README.md +++ b/README.md @@ -51,6 +51,19 @@ print(out) As you can see, any function can be used as an operation in GraphKit, even ones imported from system modules! +For debugging, you may plot the workflow with one of these methods: + +```python + graph.net.plot(show=True) # open a matplotlib window + graph.net.plot("path/to/workflow.png") # supported files: .png .dot .jpg .jpeg .pdf .svg +``` + +> **NOTE**: For plots, `graphviz` must be in your PATH, and `pydot` & `matplotlib` python packages installed. +> You may install both when installing *graphkit* with its `plot` extras: +> ```python +> pip install graphkit[plot] +> ``` + # License Code licensed under the Apache License, Version 2.0 license. See LICENSE file for terms. diff --git a/docs/source/index.rst b/docs/source/index.rst index 5c5e505c..6b5cb690 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -69,6 +69,18 @@ Here's a Python script with an example GraphKit computation graph that produces As you can see, any function can be used as an operation in GraphKit, even ones imported from system modules! 
+For debugging, you may plot the workflow with one of these methods:: + + graph.net.plot(show=True) # open a matplotlib window + graph.net.plot("path/to/workflow.png") # supported files: .png .dot .jpg .jpeg .pdf .svg + +.. NOTE:: + For plots, ``graphviz`` must be in your PATH, and ``pydot`` & ``matplotlib`` python packages installed. + You may install both when installing *graphkit* with its ``plot`` extras:: + + pip install graphkit[plot] + + License ------- diff --git a/graphkit/base.py b/graphkit/base.py index 1c04e8d5..212de939 100644 --- a/graphkit/base.py +++ b/graphkit/base.py @@ -171,8 +171,35 @@ def set_execution_method(self, method): assert method in options self._execution_method = method - def plot(self, filename=None, show=False): - self.net.plot(filename=filename, show=show) + def plot(self, filename=None, show=False, jupyter=None, + inputs=None, outputs=None, solution=None): + """ + :param str filename: + Write diagram into a file. + Common extensions are ``.png .dot .jpg .jpeg .pdf .svg`` + call :func:`plot.supported_plot_formats()` for more. + :param show: + If it evaluates to true, opens the diagram in a matplotlib window. + If it equals `-1`, it plots but does not open the Window. + :param jupyter: + If it evaluates to true, return an SVG suitable to render + in *jupyter notebook cells* (`ipython` must be installed). + :param inputs: + an optional name list, any nodes in there are plotted + as a "house" + :param outputs: + an optional name list, any nodes in there are plotted + as an "inverted-house" + :param solution: + an optional dict with values to annotate nodes + (currently content not shown, but node drawn as "filled") + + :return: + An instance of the :mod:`pydot` graph + + See :func:`graphkit.plot.plot_graph()` for the plot legend and example code. 
+ """ + return self.net.plot(filename, show, jupyter, inputs, outputs, solution) def __getstate__(self): state = Operation.__getstate__(self) diff --git a/graphkit/network.py b/graphkit/network.py index 0df3ddf8..582d0c0a 100644 --- a/graphkit/network.py +++ b/graphkit/network.py @@ -2,12 +2,11 @@ # Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms. import time -import os import networkx as nx -from io import StringIO from .base import Operation +from .modifiers import optional class DataPlaceholderNode(str): @@ -375,77 +374,40 @@ def _compute_sequential_method(self, named_inputs, outputs): return {k: cache[k] for k in iter(cache) if k in outputs} - def plot(self, filename=None, show=False): + def plot(self, filename=None, show=False, jupyter=None, + inputs=None, outputs=None, solution=None): """ - Plot the graph. + Plot a *Graphviz* graph and return it, if no other argument provided. - params: :param str filename: - Write the output to a png, pdf, or graphviz dot file. The extension - controls the output format. - - :param boolean show: - If this is set to True, use matplotlib to show the graph diagram - (Default: False) - - :returns: - An instance of the pydot graph - + Write diagram into a file. + Common extensions are ``.png .dot .jpg .jpeg .pdf .svg`` + call :func:`plot.supported_plot_formats()` for more. + :param show: + If it evaluates to true, opens the diagram in a matplotlib window. + If it equals `-1``, it plots but does not open the Window. + :param jupyter: + If it evaluates to true, return an SVG suitable to render + in *jupyter notebook cells* (`ipython` must be installed). 
+ :param inputs: + an optional name list, any nodes in there are plotted + as a "house" + :param outputs: + an optional name list, any nodes in there are plotted + as an "inverted-house" + :param solution: + an optional dict with values to annotate nodes + (currently content not shown, but node drawn as "filled") + + :return: + An instance of the :mod`pydot` graph + + See :func:`graphkit.plot.plot_graph()` for the plot legend and example code. """ - import pydot - import matplotlib.pyplot as plt - import matplotlib.image as mpimg - - assert self.graph is not None - - def get_node_name(a): - if isinstance(a, DataPlaceholderNode): - return a - return a.name - - g = pydot.Dot(graph_type="digraph") - - # draw nodes - for nx_node in self.graph.nodes(): - if isinstance(nx_node, DataPlaceholderNode): - node = pydot.Node(name=nx_node, shape="rect") - else: - node = pydot.Node(name=nx_node.name, shape="circle") - g.add_node(node) - - # draw edges - for src, dst in self.graph.edges(): - src_name = get_node_name(src) - dst_name = get_node_name(dst) - edge = pydot.Edge(src=src_name, dst=dst_name) - g.add_edge(edge) - - # save plot - if filename: - basename, ext = os.path.splitext(filename) - with open(filename, "w") as fh: - if ext.lower() == ".png": - fh.write(g.create_png()) - elif ext.lower() == ".dot": - fh.write(g.to_string()) - elif ext.lower() in [".jpg", ".jpeg"]: - fh.write(g.create_jpeg()) - elif ext.lower() == ".pdf": - fh.write(g.create_pdf()) - elif ext.lower() == ".svg": - fh.write(g.create_svg()) - else: - raise Exception("Unknown file format for saving graph: %s" % ext) - - # display graph via matplotlib - if show: - png = g.create_png() - sio = StringIO(png) - img = mpimg.imread(sio) - plt.imshow(img, aspect="equal") - plt.show() - - return g + from . 
import plot + + return plot.plot_graph(self.graph, filename, show, jupyter, + self.steps, inputs, outputs, solution) def ready_to_schedule_operation(op, has_executed, graph): @@ -494,3 +456,10 @@ def get_data_node(name, graph): if node == name and isinstance(node, DataPlaceholderNode): return node return None + + +def supported_plot_formats(): + """return automatically all `pydot` extensions withlike ``.png``""" + import pydot + + return [".%s" % f for f in pydot.Dot().formats] diff --git a/graphkit/plot.py b/graphkit/plot.py new file mode 100644 index 00000000..c9db1fe4 --- /dev/null +++ b/graphkit/plot.py @@ -0,0 +1,215 @@ +# Copyright 2016, Yahoo Inc. +# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms. + +import io +import os + +from .base import NetworkOperation, Operation +from .modifiers import optional + + +def supported_plot_formats(): + """return automatically all `pydot` extensions withlike ``.png``""" + import pydot + + return [".%s" % f for f in pydot.Dot().formats] + + +def build_pydot(graph, steps=None, inputs=None, outputs=None, solution=None): + """ Build a Graphviz graph """ + import pydot + + assert graph is not None + + def get_node_name(a): + if isinstance(a, Operation): + return a.name + return a + + dot = pydot.Dot(graph_type="digraph") + + # draw nodes + for nx_node in graph.nodes: + kw = {} + if isinstance(nx_node, str): + # Only DeleteInstructions data in steps. + if steps and nx_node in steps: + kw = {"color": "red", "penwidth": 2} + + # SHAPE change if in inputs/outputs. + # tip: https://graphviz.gitlab.io/_pages/doc/info/shapes.html + shape = "rect" + if inputs and outputs and nx_node in inputs and nx_node in outputs: + shape = "hexagon" + else: + if inputs and nx_node in inputs: + shape = "invhouse" + if outputs and nx_node in outputs: + shape = "house" + + # LABEL change from solution. 
+ if solution and nx_node in solution: + kw["style"] = "filled" + kw["fillcolor"] = "gray" + # kw["tooltip"] = nx_node, solution.get(nx_node) + node = pydot.Node(name=nx_node, shape=shape, URL="fdgfdf", **kw) + else: # Operation + kw = {} + shape = "oval" if isinstance(nx_node, NetworkOperation) else "circle" + if nx_node in steps: + kw["style"] = "bold" + node = pydot.Node(name=nx_node.name, shape=shape, **kw) + + dot.add_node(node) + + # draw edges + for src, dst in graph.edges: + src_name = get_node_name(src) + dst_name = get_node_name(dst) + kw = {} + if isinstance(dst, Operation) and any( + n == src and isinstance(n, optional) for n in dst.needs + ): + kw["style"] = "dashed" + edge = pydot.Edge(src=src_name, dst=dst_name, **kw) + dot.add_edge(edge) + + # draw steps sequence + if steps and len(steps) > 1: + it1 = iter(steps) + it2 = iter(steps) + next(it2) + for i, (src, dst) in enumerate(zip(it1, it2), 1): + src_name = get_node_name(src) + dst_name = get_node_name(dst) + edge = pydot.Edge( + src=src_name, + dst=dst_name, + label=str(i), + style="dotted", + color="green", + fontcolor="green", + fontname="bold", + fontsize=18, + penwidth=3, + arrowhead="vee", + ) + dot.add_edge(edge) + + return dot + + +def plot_graph( + graph, + filename=None, + show=False, + jupyter=False, + steps=None, + inputs=None, + outputs=None, + solution=None, +): + """ + Plot a *Graphviz* graph/steps and return it, if no other argument provided. + + Legend: + + + NODES: + + - **circle**: function + - **oval**: subgraph function + - **house**: given input + - **inversed-house**: asked output + - **polygon**: given both as input & asked as output (what?) + - **square**: intermediate data, neither given nor asked. + - **red frame**: delete-instruction, to free up memory. + - **filled**: data node has a value in `solution`, shown in tooltip. + - **thick frame**: function/data node visited. 
+ + ARROWS + + - **solid black arrows**: dependencies (source-data are ``need``-ed + by target-operations, source-operations ``provide`` target-data) + - **dashed black arrows**: optional needs + - **green-dotted arrows**: execution steps labeled in succession + + :param graph: + what to plot + :param str filename: + Write diagram into a file. + Common extensions are ``.png .dot .jpg .jpeg .pdf .svg`` + call :func:`plot.supported_plot_formats()` for more. + :param show: + If it evaluates to true, opens the diagram in a matplotlib window. + If it equals ``-1``, it plots but does not open the Window. + :param jupyter: + If it evaluates to true, return an SVG suitable to render + in *jupyter notebook cells* (`ipython` must be installed). + :param steps: + a list of nodes & instructions to overlay on the diagram + :param inputs: + an optional name list, any nodes in there are plotted + as a "house" + :param outputs: + an optional name list, any nodes in there are plotted + as an "inverted-house" + :param solution: + an optional dict with values to annotate nodes + (currently content not shown, but node drawn as "filled") + + :return: + An instance of the :mod:`pydot` graph + + **Example:** + + >>> from graphkit import compose, operation + >>> from graphkit.modifiers import optional + + >>> pipeline = compose(name="pipeline")( + ... operation(name="add", needs=["a", "b1"], provides=["ab1"])(add), + ... operation(name="sub", needs=["a", optional("b2")], provides=["ab2"])(lambda a, b=1: a-b), + ... operation(name="abb", needs=["ab1", "ab2"], provides=["asked"])(add), + ... 
) + + >>> inputs = {'a': 1, 'b1': 2} + >>> solution=pipeline(inputs) + >>> pipeline.plot('plot.svg', inputs=inputs, solution=solution, outputs=['asked', 'b1']); + + """ + dot = build_pydot(graph, steps, inputs, outputs, solution) + + # Save plot + # + if filename: + formats = supported_plot_formats() + _basename, ext = os.path.splitext(filename) + if not ext.lower() in formats: + raise ValueError( + "Unknown file format for saving graph: %s" + " File extensions must be one of: %s" % (ext, " ".join(formats)) + ) + + dot.write(filename, format=ext.lower()[1:]) + + ## Return an SVG renderable in jupyter. + # + if jupyter: + from IPython.display import SVG + + dot = SVG(data=dot.create_svg()) + + ## Display graph via matplotlib + # + if show: + import matplotlib.pyplot as plt + import matplotlib.image as mpimg + + png = dot.create_png() + sio = io.BytesIO(png) + img = mpimg.imread(sio) + plt.imshow(img, aspect="equal") + if show != -1: + plt.show() + + return dot diff --git a/setup.py b/setup.py index bd7883f4..654f4e13 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,16 @@ with io.open('graphkit/__init__.py', 'rt', encoding='utf8') as f: version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1) +plot_reqs = [ + "ipython; python_version >= '3.5'", # to test jupyter plot. 
+ "matplotlib", # to test plot + "pydot", # to test plot +] +test_reqs = plot_reqs + [ + "pytest", + "pytest-cov", +] + setup( name='graphkit', version=version, @@ -28,11 +38,15 @@ author_email='huyng@yahoo-inc.com', url='http://github.com/yahoo/graphkit', packages=['graphkit'], - install_requires=['networkx'], + install_requires=[ + "networkx; python_version >= '3.5'", + "networkx == 2.2; python_version < '3.5'", + ], extras_require={ - 'plot': ['pydot', 'matplotlib'] + 'plot': plot_reqs, + 'test': test_reqs, }, - tests_require=['numpy'], + tests_require=test_reqs, license='Apache-2.0', keywords=['graph', 'computation graph', 'DAG', 'directed acyclical graph'], classifiers=[ diff --git a/test/test_graphkit.py b/test/test_graphkit.py index bd97b317..7db2e973 100644 --- a/test/test_graphkit.py +++ b/test/test_graphkit.py @@ -6,7 +6,8 @@ from pprint import pprint from operator import add -from numpy.testing import assert_raises + +import pytest import graphkit.network as network import graphkit.modifiers as modifiers @@ -180,9 +181,10 @@ def test_pruning_raises_for_bad_output(): # Request two outputs we can compute and one we can't compute. Assert # that this raises a ValueError. - assert_raises(ValueError, net, {'a': 1, 'b': 2, 'c': 3, 'd': 4}, - outputs=['sum1', 'sum3', 'sum4']) - + with pytest.raises(ValueError) as exinfo: + net({'a': 1, 'b': 2, 'c': 3, 'd': 4}, + outputs=['sum1', 'sum3', 'sum4']) + assert exinfo.match('sum4') def test_optional(): # Test that optional() needs work as expected. diff --git a/test/test_plot.py b/test/test_plot.py new file mode 100644 index 00000000..39ad039f --- /dev/null +++ b/test/test_plot.py @@ -0,0 +1,108 @@ +# Copyright 2016, Yahoo Inc. +# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms. 
+ +import sys +from operator import add + +import pytest + +from graphkit import base, compose, network, operation, plot +from graphkit.modifiers import optional + + +@pytest.fixture +def pipeline(): + return compose(name="netop")( + operation(name="add", needs=["a", "b1"], provides=["ab1"])(add), + operation(name="sub", needs=["a", optional("b2")], provides=["ab2"])( + lambda a, b=1: a - b + ), + operation(name="abb", needs=["ab1", "ab2"], provides=["asked"])(add), + ) + + +@pytest.fixture(params=[{"a": 1}, {"a": 1, "b1": 2}]) +def inputs(request): + return {"a": 1, "b1": 2} + + +@pytest.fixture(params=[None, ("a", "b1")]) +def input_names(request): + return request.param + + +@pytest.fixture(params=[None, ["asked", "b1"]]) +def outputs(request): + return request.param + + +@pytest.fixture(params=[None, 1]) +def solution(pipeline, inputs, outputs, request): + return request.param and pipeline(inputs, outputs) + + +###### TEST CASES ####### +## + + +def test_plotting_docstring(): + common_formats = ".png .dot .jpg .jpeg .pdf .svg".split() + for ext in common_formats: + assert ext in plot.plot_graph.__doc__ + assert ext in base.NetworkOperation.plot.__doc__ + assert ext in network.Network.plot.__doc__ + + +def test_plot_formats(pipeline, input_names, outputs, solution, tmp_path): + ## Generate all formats (not needing to save files) + + # ...these are not working on my PC, or travis. 
+ forbidden_formats = ".dia .hpgl .mif .mp .pcl .pic .vtx .xlib".split() + prev_dot = None + for ext in plot.supported_plot_formats(): + if ext not in forbidden_formats: + dot = pipeline.plot(inputs=input_names, outputs=outputs, solution=solution) + assert dot + assert ext == ".jpg" or dot != prev_dot + prev_dot = dot + + +def test_plot_bad_format(pipeline, tmp_path): + with pytest.raises(ValueError, match="Unknown file format") as exinfo: + pipeline.plot(filename="bad.format") + + ## Check help msg lists all siupported formats + for ext in plot.supported_plot_formats(): + assert exinfo.match(ext) + + +def test_plot_write_file(pipeline, tmp_path): + # Try saving a file from one format. + + fpath = tmp_path / "workflow.png" + + dot = pipeline.plot(str(fpath)) + assert fpath.exists() + assert dot + + +def test_plot_matplib(pipeline, tmp_path): + ## Try matplotlib Window, but # without opening a Window. + + if sys.version_info < (3, 5): + # On PY< 3.5 it fails with: + # nose.proxy.TclError: no display name and no $DISPLAY environment variable + # eg https://travis-ci.org/ankostis/graphkit/jobs/593957996 + import matplotlib + + matplotlib.use("Agg") + # do not open window in headless travis + assert pipeline.plot(show=-1) + + +@pytest.mark.skipif(sys.version_info < (3, 5), reason="ipython-7+ dropped PY3.4-") +def test_plot_jupyter(pipeline, tmp_path): + ## Try returned Jupyter SVG. + + dot = pipeline.plot(jupyter=True) + assert "display.SVG" in str(type(dot))