Skip to content

Make fetcher and resolver configurable. #240

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 54 additions & 20 deletions cwltool/load_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,23 +6,33 @@
import logging
import re
import urlparse
from schema_salad.ref_resolver import Loader

from schema_salad.ref_resolver import Loader, Fetcher, DefaultFetcher
import schema_salad.validate as validate
from schema_salad.validate import ValidationException
import schema_salad.schema as schema
import requests

from typing import Any, AnyStr, Callable, cast, Dict, Text, Tuple, Union

from avro.schema import Names

from . import update
from . import process
from .process import Process, shortname
from .errors import WorkflowException
from typing import Any, AnyStr, Callable, cast, Dict, Text, Tuple, Union

_logger = logging.getLogger("cwltool")

def fetch_document(argsworkflow, resolver=None):
# type: (Union[Text, dict[Text, Any]], Any) -> Tuple[Loader, Dict[Text, Any], Text]
def fetch_document(argsworkflow, # type: Union[Text, dict[Text, Any]]
resolver=None, # type: Callable[[Loader, Union[Text, dict[Text, Any]]], Text]
fetcher_constructor=DefaultFetcher # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher]
):
# type: (...) -> Tuple[Loader, Dict[Text, Any], Text]
"""Retrieve a CWL document."""
document_loader = Loader({"cwl": "https://w3id.org/cwl/cwl#", "id": "@id"})

document_loader = Loader({"cwl": "https://w3id.org/cwl/cwl#", "id": "@id"},
fetcher_constructor=fetcher_constructor)

uri = None # type: Text
workflowobj = None # type: Dict[Text, Any]
Expand Down Expand Up @@ -95,16 +105,23 @@ def _convert_stdstreams_to_files(workflowobj):
for entry in workflowobj:
_convert_stdstreams_to_files(entry)

def validate_document(document_loader, workflowobj, uri,
enable_dev=False, strict=True, preprocess_only=False):
# type: (Loader, Dict[Text, Any], Text, bool, bool, bool) -> Tuple[Loader, Names, Union[Dict[Text, Any], List[Dict[Text, Any]]], Dict[Text, Any], Text]
def validate_document(document_loader, # type: Loader
workflowobj, # type: Dict[Text, Any]
uri, # type: Text
enable_dev=False, # type: bool
strict=True, # type: bool
preprocess_only=False, # type: bool
fetcher_constructor=DefaultFetcher # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher]
):
# type: (...) -> Tuple[Loader, Names, Union[Dict[Text, Any], List[Dict[Text, Any]]], Dict[Text, Any], Text]
"""Validate a CWL document."""

jobobj = None
if "cwl:tool" in workflowobj:
jobobj, _ = document_loader.resolve_all(workflowobj, uri)
uri = urlparse.urljoin(uri, workflowobj["https://w3id.org/cwl/cwl#tool"])
del cast(dict, jobobj)["https://w3id.org/cwl/cwl#tool"]
workflowobj = fetch_document(uri)[1]
workflowobj = fetch_document(uri, fetcher_constructor=fetcher_constructor)[1]

if isinstance(workflowobj, list):
workflowobj = {
Expand All @@ -130,12 +147,16 @@ def validate_document(document_loader, workflowobj, uri,
workflowobj["$graph"] = workflowobj["@graph"]
del workflowobj["@graph"]

(document_loader, avsc_names) = \
(sch_document_loader, avsc_names) = \
process.get_schema(workflowobj["cwlVersion"])[:2]

if isinstance(avsc_names, Exception):
raise avsc_names

document_loader = Loader(sch_document_loader.ctx, schemagraph=sch_document_loader.graph,
idx=document_loader.idx, cache=sch_document_loader.cache,
fetcher_constructor=fetcher_constructor)

workflowobj["id"] = fileuri
processobj, metadata = document_loader.resolve_all(workflowobj, fileuri)
if not isinstance(processobj, (dict, list)):
Expand Down Expand Up @@ -165,8 +186,14 @@ def validate_document(document_loader, workflowobj, uri,
return document_loader, avsc_names, processobj, metadata, uri


def make_tool(document_loader, avsc_names, metadata, uri, makeTool, kwargs):
# type: (Loader, Names, Dict[Text, Any], Text, Callable[..., Process], Dict[AnyStr, Any]) -> Process
def make_tool(document_loader, # type: Loader
avsc_names, # type: Names
metadata, # type: Dict[Text, Any]
uri, # type: Text
makeTool, # type: Callable[..., Process]
kwargs # type: dict
):
# type: (...) -> Process
"""Make a Python CWL object."""
resolveduri = document_loader.resolve_ref(uri)[0]

Expand All @@ -179,8 +206,10 @@ def make_tool(document_loader, avsc_names, metadata, uri, makeTool, kwargs):
"one of #%s" % ", #".join(
urlparse.urldefrag(i["id"])[1] for i in resolveduri
if "id" in i))
else:
elif isinstance(resolveduri, dict):
processobj = resolveduri
else:
raise Exception("Must resolve to list or dict")

kwargs = kwargs.copy()
kwargs.update({
Expand All @@ -200,14 +229,19 @@ def make_tool(document_loader, avsc_names, metadata, uri, makeTool, kwargs):
return tool


def load_tool(argsworkflow, makeTool, kwargs=None,
enable_dev=False,
strict=True,
resolver=None):
# type: (Union[Text, dict[Text, Any]], Callable[...,Process], Dict[AnyStr, Any], bool, bool, Any) -> Any
document_loader, workflowobj, uri = fetch_document(argsworkflow, resolver=resolver)
def load_tool(argsworkflow, # type: Union[Text, Dict[Text, Any]]
makeTool, # type: Callable[..., Process]
kwargs=None, # type: dict
enable_dev=False, # type: bool
strict=True, # type: bool
resolver=None, # type: Callable[[Loader, Union[Text, dict[Text, Any]]], Text]
fetcher_constructor=DefaultFetcher # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher]
):
# type: (...) -> Process

document_loader, workflowobj, uri = fetch_document(argsworkflow, resolver=resolver, fetcher_constructor=fetcher_constructor)
document_loader, avsc_names, processobj, metadata, uri = validate_document(
document_loader, workflowobj, uri, enable_dev=enable_dev,
strict=strict)
strict=strict, fetcher_constructor=fetcher_constructor)
return make_tool(document_loader, avsc_names, metadata, uri,
makeTool, kwargs if kwargs else {})
37 changes: 21 additions & 16 deletions cwltool/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
import functools

import rdflib
import requests
from typing import (Union, Any, AnyStr, cast, Callable, Dict, Sequence, Text,
Tuple, Type, IO)

from schema_salad.ref_resolver import Loader
from schema_salad.ref_resolver import Loader, Fetcher, DefaultFetcher
import schema_salad.validate as validate
import schema_salad.jsonld_context
import schema_salad.makedoc
Expand Down Expand Up @@ -392,7 +393,7 @@ def generate_parser(toolparser, tool, namemap, records):

def load_job_order(args, t, stdin, print_input_deps=False, relative_deps=False,
stdout=sys.stdout, make_fs_access=None):
# type: (argparse.Namespace, Process, IO[Any], bool, bool, IO[Any], Type[StdFsAccess]) -> Union[int, Tuple[Dict[Text, Any], Text]]
# type: (argparse.Namespace, Process, IO[Any], bool, bool, IO[Any], Callable[[Text], StdFsAccess]) -> Union[int, Tuple[Dict[Text, Any], Text]]

job_order_object = None

Expand Down Expand Up @@ -553,18 +554,21 @@ def versionstring():
return u"%s %s" % (sys.argv[0], "unknown version")


def main(argsl=None,
args=None,
executor=single_job_executor,
makeTool=workflow.defaultMakeTool,
selectResources=None,
stdin=sys.stdin,
stdout=sys.stdout,
stderr=sys.stderr,
versionfunc=versionstring,
job_order_object=None,
make_fs_access=StdFsAccess):
# type: (List[str], argparse.Namespace, Callable[..., Union[Text, Dict[Text, Text]]], Callable[..., Process], Callable[[Dict[Text, int]], Dict[Text, int]], IO[Any], IO[Any], IO[Any], Callable[[], Text], Union[int, Tuple[Dict[Text, Any], Text]], Type[StdFsAccess]) -> int
def main(argsl=None, # type: List[str]
args=None, # type: argparse.Namespace
executor=single_job_executor, # type: Callable[..., Union[Text, Dict[Text, Text]]]
makeTool=workflow.defaultMakeTool, # type: Callable[..., Process]
selectResources=None, # type: Callable[[Dict[Text, int]], Dict[Text, int]]
stdin=sys.stdin, # type: IO[Any]
stdout=sys.stdout, # type: IO[Any]
stderr=sys.stderr, # type: IO[Any]
versionfunc=versionstring, # type: Callable[[], Text]
job_order_object=None, # type: Union[Tuple[Dict[Text, Any], Text], int]
make_fs_access=StdFsAccess, # type: Callable[[Text], StdFsAccess]
fetcher_constructor=DefaultFetcher, # type: Callable[[Dict[unicode, unicode], requests.sessions.Session], Fetcher]
resolver=tool_resolver
):
# type: (...) -> int

_logger.removeHandler(defaultStreamHandler)
stderr_handler = logging.StreamHandler(stderr)
Expand Down Expand Up @@ -624,7 +628,7 @@ def main(argsl=None,
draft2tool.ACCEPTLIST_RE = draft2tool.ACCEPTLIST_EN_RELAXED_RE

try:
document_loader, workflowobj, uri = fetch_document(args.workflow, resolver=tool_resolver)
document_loader, workflowobj, uri = fetch_document(args.workflow, resolver=resolver, fetcher_constructor=fetcher_constructor)

if args.print_deps:
printdeps(workflowobj, document_loader, stdout, args.relative_deps, uri)
Expand All @@ -633,7 +637,8 @@ def main(argsl=None,
document_loader, avsc_names, processobj, metadata, uri \
= validate_document(document_loader, workflowobj, uri,
enable_dev=args.enable_dev, strict=args.strict,
preprocess_only=args.print_pre or args.pack)
preprocess_only=args.print_pre or args.pack,
fetcher_constructor=fetcher_constructor)

if args.pack:
stdout.write(print_pack(document_loader, processobj, uri, metadata))
Expand Down
1 change: 0 additions & 1 deletion cwltool/stdfsaccess.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import glob
import os


class StdFsAccess(object):

def __init__(self, basedir): # type: (Text) -> None
Expand Down
45 changes: 45 additions & 0 deletions tests/test_fetch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import unittest
import schema_salad.ref_resolver
import schema_salad.main
import schema_salad.schema
from schema_salad.jsonld_context import makerdf
from pkg_resources import Requirement, resource_filename, ResolutionError # type: ignore
import rdflib
import ruamel.yaml as yaml
import json
import os

from cwltool.main import main
from cwltool.workflow import defaultMakeTool
from cwltool.load_tool import load_tool

class FetcherTest(unittest.TestCase):
def test_fetcher(self):
class TestFetcher(schema_salad.ref_resolver.Fetcher):
def __init__(self, a, b):
pass

def fetch_text(self, url): # type: (unicode) -> unicode
if url == "baz:bar/foo.cwl":
return """
cwlVersion: v1.0
class: CommandLineTool
baseCommand: echo
inputs: []
outputs: []
"""
else:
raise RuntimeError("Not foo.cwl")

def check_exists(self, url): # type: (unicode) -> bool
if url == "baz:bar/foo.cwl":
return True
else:
return False

def test_resolver(d, a):
return "baz:bar/" + a

load_tool("foo.cwl", defaultMakeTool, resolver=test_resolver, fetcher_constructor=TestFetcher)

self.assertEquals(0, main(["--print-pre", "--debug", "foo.cwl"], resolver=test_resolver, fetcher_constructor=TestFetcher))