|
| 1 | +# Copyright 2016, Yahoo Inc. |
| 2 | +# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms. |
| 3 | + |
| 4 | +from operator import add |
| 5 | + |
| 6 | +import pytest |
| 7 | +import sys |
| 8 | + |
| 9 | +from graphkit import compose, network, operation |
| 10 | +from graphkit.modifiers import optional |
| 11 | + |
| 12 | + |
| 13 | +@pytest.fixture |
| 14 | +def pipeline(): |
| 15 | + return compose(name="netop")( |
| 16 | + operation(name="add", needs=["a", "b1"], provides=["ab1"])(add), |
| 17 | + operation(name="sub", needs=["a", optional("b2")], provides=["ab2"])( |
| 18 | + lambda a, b=1: a - b |
| 19 | + ), |
| 20 | + operation(name="abb", needs=["ab1", "ab2"], provides=["asked"])(add), |
| 21 | + ) |
| 22 | + |
| 23 | + |
| 24 | +@pytest.fixture(params=[{"a": 1}, {"a": 1, "b1": 2}]) |
| 25 | +def inputs(request): |
| 26 | + return {"a": 1, "b1": 2} |
| 27 | + |
| 28 | + |
| 29 | +@pytest.fixture(params=[None, ("a", "b1")]) |
| 30 | +def input_names(request): |
| 31 | + return request.param |
| 32 | + |
| 33 | + |
| 34 | +@pytest.fixture(params=[None, ["asked", "b1"]]) |
| 35 | +def outputs(request): |
| 36 | + return request.param |
| 37 | + |
| 38 | + |
| 39 | +@pytest.fixture(params=[None, 1]) |
| 40 | +def solution(pipeline, inputs, outputs, request): |
| 41 | + return request.param and pipeline(inputs, outputs) |
| 42 | + |
| 43 | + |
| 44 | +###### TEST CASES ####### |
| 45 | +## |
| 46 | + |
| 47 | + |
| 48 | +def test_plotting_docstring(): |
| 49 | + common_formats = ".png .dot .jpg .jpeg .pdf .svg".split() |
| 50 | + for ext in common_formats: |
| 51 | + assert ext in network.plot_graph.__doc__ |
| 52 | + |
| 53 | + |
| 54 | +def test_plot_formats(pipeline, input_names, outputs, solution, tmp_path): |
| 55 | + ## Generate all formats (not needing to save files) |
| 56 | + |
| 57 | + # ...these are not working on my PC, or travis. |
| 58 | + forbidden_formats = ".dia .hpgl .mif .mp .pcl .pic .vtx .xlib".split() |
| 59 | + prev_dot = None |
| 60 | + for ext in network.supported_plot_formats(): |
| 61 | + if ext not in forbidden_formats: |
| 62 | + dot = pipeline.plot(inputs=input_names, outputs=outputs, solution=solution) |
| 63 | + assert dot |
| 64 | + assert ext == ".jpg" or dot != prev_dot |
| 65 | + prev_dot = dot |
| 66 | + |
| 67 | + |
| 68 | +def test_plot_bad_format(pipeline, tmp_path): |
| 69 | + with pytest.raises(ValueError, match="Unknown file format") as exinfo: |
| 70 | + pipeline.plot(filename="bad.format") |
| 71 | + |
| 72 | + ## Check help msg lists all siupported formats |
| 73 | + for ext in network.supported_plot_formats(): |
| 74 | + assert exinfo.match(ext) |
| 75 | + |
| 76 | + |
| 77 | +def test_plot_write_file(pipeline, tmp_path): |
| 78 | + # Try saving a file from one format. |
| 79 | + |
| 80 | + fpath = tmp_path / "workflow.png" |
| 81 | + |
| 82 | + dot = pipeline.plot(str(fpath)) |
| 83 | + assert fpath.exists() |
| 84 | + assert dot |
| 85 | + |
| 86 | + |
| 87 | +def test_plot_matplib(pipeline, tmp_path): |
| 88 | + ## Try matplotlib Window, but # without opening a Window. |
| 89 | + |
| 90 | + if sys.version_info < (3, 5): |
| 91 | + # On PY< 3.5 it fails with: |
| 92 | + # nose.proxy.TclError: no display name and no $DISPLAY environment variable |
| 93 | + # eg https://travis-ci.org/ankostis/graphkit/jobs/593957996 |
| 94 | + import matplotlib |
| 95 | + |
| 96 | + matplotlib.use("Agg") |
| 97 | + # do not open window in headless travis |
| 98 | + assert pipeline.plot(show=-1) |
| 99 | + |
| 100 | + |
| 101 | +@pytest.mark.skipif(sys.version_info < (3, 5), reason="ipython-7+ dropped PY3.4-") |
| 102 | +def test_plot_jupyter(pipeline, tmp_path): |
| 103 | + ## Try returned Jupyter SVG. |
| 104 | + |
| 105 | + dot = pipeline.plot(jupyter=True) |
| 106 | + assert "display.SVG" in str(type(dot)) |
0 commit comments