diff --git a/src/pyff/api.py b/src/pyff/api.py index 22ba414a..f876c1d3 100644 --- a/src/pyff/api.py +++ b/src/pyff/api.py @@ -2,7 +2,7 @@ import threading from datetime import datetime, timedelta from json import dumps -from typing import Any, Iterable, List, Mapping +from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional, Tuple import pkg_resources import pyramid.httpexceptions as exc @@ -13,6 +13,7 @@ from lxml import etree from pyramid.config import Configurator from pyramid.events import NewRequest +from pyramid.request import Request from pyramid.response import Response from six import b from six.moves.urllib_parse import quote_plus @@ -22,27 +23,29 @@ from pyff.logs import get_log from pyff.pipes import plumbing from pyff.repo import MDRepository -from pyff.resource import Resource, ResourceInfo +from pyff.resource import Resource from pyff.samlmd import entity_display_name -from pyff.utils import b2u, dumptree, duration2timedelta, hash_id, json_serializer, utc_now +from pyff.utils import b2u, dumptree, hash_id, json_serializer, utc_now log = get_log(__name__) class NoCache(object): - def __init__(self): + """ Dummy implementation for when caching isn't enabled """ + + def __init__(self) -> None: pass - def __getitem__(self, item): + def __getitem__(self, item: Any) -> None: return None - def __setitem__(self, instance, value): + def __setitem__(self, instance: Any, value: Any) -> Any: return value -def robots_handler(request): +def robots_handler(request: Request) -> Response: """ - Impelements robots.txt + Implements robots.txt :param request: the HTTP request :return: robots.txt @@ -55,7 +58,7 @@ def robots_handler(request): ) -def status_handler(request): +def status_handler(request: Request) -> Response: """ Implements the /api/status endpoint @@ -80,34 +83,38 @@ def status_handler(request): class MediaAccept(object): - def __init__(self, accept): + def __init__(self, accept: str): self._type = AcceptableType(accept) - def 
has_key(self, key): + def has_key(self, key: Any) -> bool: # Literal[True]: return True - def get(self, item): + def get(self, item: Any) -> Any: return self._type.matches(item) - def __contains__(self, item): + def __contains__(self, item: Any) -> Any: return self._type.matches(item) - def __str__(self): + def __str__(self) -> str: return str(self._type) xml_types = ('text/xml', 'application/xml', 'application/samlmetadata+xml') -def _is_xml_type(accepter): +def _is_xml_type(accepter: MediaAccept) -> bool: return any([x in accepter for x in xml_types]) -def _is_xml(data): +def _is_xml(data: Any) -> bool: return isinstance(data, (etree._Element, etree._ElementTree)) -def _fmt(data, accepter): +def _fmt(data: Any, accepter: MediaAccept) -> Tuple[str, str]: + """ + Format data according to the accepted content type of the requester. + Return data as string (either XML or json) and a content-type. + """ if data is None or len(data) == 0: return "", 'text/plain' if _is_xml(data) and _is_xml_type(accepter): @@ -127,7 +134,7 @@ def call(entry: str) -> None: return None -def request_handler(request): +def request_handler(request: Request) -> Response: """ The main GET request handler for pyFF. Implements caching and forwards the request to process_handler @@ -146,7 +153,7 @@ def request_handler(request): return r -def process_handler(request): +def process_handler(request: Request) -> Response: """ The main request handler for pyFF. Implements API call hooks and content negotiation. @@ -155,7 +162,8 @@ def process_handler(request): """ _ctypes = {'xml': 'application/samlmetadata+xml;application/xml;text/xml', 'json': 'application/json'} - def _d(x, do_split=True): + def _d(x: Optional[str], do_split: bool = True) -> Tuple[Optional[str], Optional[str]]: + """ Split a path into a base component and an extension. 
""" if x is not None: x = x.strip() @@ -170,7 +178,7 @@ def _d(x, do_split=True): return x, None - log.debug(request) + log.debug(f'Processing request: {request}') if request.matchdict is None: raise exc.exception_response(400) @@ -182,18 +190,18 @@ def _d(x, do_split=True): pass entry = request.matchdict.get('entry', 'request') - path = list(request.matchdict.get('path', [])) + path_elem = list(request.matchdict.get('path', [])) match = request.params.get('q', request.params.get('query', None)) # Enable matching on scope. match = match.split('@').pop() if match and not match.endswith('@') else match log.debug("match={}".format(match)) - if 0 == len(path): - path = ['entities'] + if not path_elem: + path_elem = ['entities'] - alias = path.pop(0) - path = '/'.join(path) + alias = path_elem.pop(0) + path = '/'.join(path_elem) # Ugly workaround bc WSGI drops double-slashes. path = path.replace(':/', '://') @@ -226,23 +234,31 @@ def _d(x, do_split=True): accept = str(request.accept).split(',')[0] valid_accept = accept and not ('application/*' in accept or 'text/*' in accept or '*/*' in accept) - path_no_extension, extension = _d(path, True) - accept_from_extension = _ctypes.get(extension, accept) + new_path: Optional[str] = path + path_no_extension, extension = _d(new_path, True) + accept_from_extension = accept + if extension: + accept_from_extension = _ctypes.get(extension, accept) if policy == 'extension': - path = path_no_extension + new_path = path_no_extension if not valid_accept: accept = accept_from_extension elif policy == 'adaptive': if not valid_accept: - path = path_no_extension + new_path = path_no_extension accept = accept_from_extension - if pfx and path: - q = "{%s}%s" % (pfx, path) - path = "/%s/%s" % (alias, path) + if not accept: + log.warning('Could not determine accepted response type') + raise exc.exception_response(400) + + q: Optional[str] + if pfx and new_path: + q = f'{{{pfx}}}{new_path}' + new_path = f'/{alias}/{new_path}' else: - q = path + 
q = new_path try: accepter = MediaAccept(accept) @@ -254,18 +270,19 @@ def _d(x, do_split=True): 'url': request.current_route_url(), 'select': q, 'match': match.lower() if match else match, - 'path': path, + 'path': new_path, 'stats': {}, } r = p.process(request.registry.md, state=state, raise_exceptions=True, scheduler=request.registry.scheduler) - log.debug(r) + log.debug(f'Plumbing process result: {r}') if r is None: r = [] response = Response() - response.headers.update(state.get('headers', {})) - ctype = state.get('headers').get('Content-Type', None) + _headers = state.get('headers', {}) + response.headers.update(_headers) + ctype = _headers.get('Content-Type', None) if not ctype: r, t = _fmt(r, accepter) ctype = t @@ -280,20 +297,20 @@ def _d(x, do_split=True): import traceback log.debug(traceback.format_exc()) - log.warning(ex) + log.warning(f'Exception from processing pipeline: {ex}') raise exc.exception_response(409) except BaseException as ex: import traceback log.debug(traceback.format_exc()) - log.error(ex) + log.error(f'Exception from processing pipeline: {ex}') raise exc.exception_response(500) if request.method == 'GET': raise exc.exception_response(404) -def webfinger_handler(request): +def webfinger_handler(request: Request) -> Response: """An implementation the webfinger protocol (http://tools.ietf.org/html/draft-ietf-appsawg-webfinger-12) in order to provide information about up and downstream metadata available at @@ -324,7 +341,7 @@ def webfinger_handler(request): "subject": "http://reep.refeds.org:8080" } - Depending on which version of pyFF your're running and the configuration you + Depending on which version of pyFF you're running and the configuration you may also see downstream metadata listed using the 'role' attribute to the link elements. 
""" @@ -335,11 +352,11 @@ def webfinger_handler(request): if resource is None: resource = request.host_url - jrd = dict() - dt = datetime.now() + duration2timedelta("PT1H") + jrd: Dict[str, Any] = dict() + dt = datetime.now() + timedelta(hours=1) jrd['expires'] = dt.isoformat() jrd['subject'] = request.host_url - links = list() + links: List[Dict[str, Any]] = list() jrd['links'] = links _dflt_rels = { @@ -352,7 +369,7 @@ def webfinger_handler(request): else: rel = [rel] - def _links(url, title=None): + def _links(url: str, title: Any = None) -> None: if url.startswith('/'): url = url.lstrip('/') for r in rel: @@ -381,7 +398,7 @@ def _links(url, title=None): return response -def resources_handler(request): +def resources_handler(request: Request) -> Response: """ Implements the /api/resources endpoint @@ -409,7 +426,7 @@ def _info(r: Resource) -> Mapping[str, Any]: return response -def pipeline_handler(request): +def pipeline_handler(request: Request) -> Response: """ Implements the /api/pipeline endpoint @@ -422,7 +439,7 @@ def pipeline_handler(request): return response -def search_handler(request): +def search_handler(request: Request) -> Response: """ Implements the /api/search endpoint @@ -438,7 +455,7 @@ def search_handler(request): log.debug("match={}".format(match)) store = request.registry.md.store - def _response(): + def _response() -> Generator[bytes, bytes, None]: yield b('[') in_loop = False entities = store.search(query=match.lower(), entity_filter=entity_filter) @@ -454,8 +471,8 @@ def _response(): return response -def add_cors_headers_response_callback(event): - def cors_headers(request, response): +def add_cors_headers_response_callback(event: NewRequest) -> None: + def cors_headers(request: Request, response: Response) -> None: response.headers.update( { 'Access-Control-Allow-Origin': '*', @@ -469,7 +486,7 @@ def cors_headers(request, response): event.request.add_response_callback(cors_headers) -def launch_memory_usage_server(port=9002): +def 
launch_memory_usage_server(port: int = 9002) -> None: import cherrypy import dowser @@ -479,7 +496,7 @@ def launch_memory_usage_server(port=9002): cherrypy.engine.start() -def mkapp(*args, **kwargs): +def mkapp(*args: Any, **kwargs: Any) -> Any: md = kwargs.pop('md', None) if md is None: md = MDRepository() @@ -501,7 +518,9 @@ def mkapp(*args, **kwargs): for mn in config.modules: importlib.import_module(mn) - pipeline = args or None + pipeline = None + if args: + pipeline = list(args) if pipeline is None and config.pipeline: pipeline = [config.pipeline] diff --git a/src/pyff/builtins.py b/src/pyff/builtins.py index 316c1fcc..3828839c 100644 --- a/src/pyff/builtins.py +++ b/src/pyff/builtins.py @@ -13,6 +13,7 @@ from copy import deepcopy from datetime import datetime from distutils.util import strtobool +from typing import Dict, Optional import ipaddr import six @@ -60,7 +61,7 @@ @pipe -def dump(req, *opts): +def dump(req: Plumbing.Request, *opts): """ Print a representation of the entities set on stdout. Useful for testing. 
@@ -76,7 +77,7 @@ def dump(req, *opts): @pipe(name="map") -def _map(req, *opts): +def _map(req: Plumbing.Request, *opts): """ loop over the entities in a selection @@ -112,7 +113,7 @@ def _p(e): @pipe(name="then") -def _then(req, *opts): +def _then(req: Plumbing.Request, *opts): """ Call a named 'when' clause and return - akin to macro invocations for pyFF """ @@ -122,16 +123,16 @@ def _then(req, *opts): @pipe(name="log_entity") -def _log_entity(req, *opts): +def _log_entity(req: Plumbing.Request, *opts): """ log the request id as it is processed (typically the entity_id) """ - log.info(req.id) + log.info(str(req.id)) return req.t @pipe(name="print") -def _print_t(req, *opts): +def _print_t(req: Plumbing.Request, *opts): """ Print whatever is in the active tree without transformation @@ -148,7 +149,9 @@ def _print_t(req, *opts): output: "somewhere.foo" """ - fn = req.args.get('output', None) + fn = None + if isinstance(req.args, dict): + fn = req.args.get('output', None) if fn is not None: safe_write(fn, req.t) else: @@ -156,7 +159,7 @@ def _print_t(req, *opts): @pipe -def end(req, *opts): +def end(req: Plumbing.Request, *opts): """ Exit with optional error code and message. @@ -176,7 +179,7 @@ def end(req, *opts): """ code = 0 - if req.args is not None: + if isinstance(req.args, dict): code = req.args.get('code', 0) msg = req.args.get('message', None) if msg is not None: @@ -185,7 +188,7 @@ def end(req, *opts): @pipe -def fork(req, *opts): +def fork(req: Plumbing.Request, *opts): """ Make a copy of the working tree and process the arguments as a pipleline. This essentially resets the working tree and allows a new plumbing to run. Useful for producing multiple outputs from a single source. 
@@ -250,7 +253,10 @@ def fork(req, *opts): if req.t is not None: nt = deepcopy(req.t) - ip = Plumbing(pipeline=req.args, pid="%s.fork" % req.plumbing.pid) + if not isinstance(req.args, list): + raise ValueError('Non-list arguments to "fork" not allowed') + + ip = Plumbing(pipeline=req.args, pid=f'{req.plumbing.pid}.fork') ireq = Plumbing.Request(ip, req.md, t=nt, scheduler=req.scheduler) ireq.set_id(req.id) ireq.set_parent(req) @@ -279,7 +285,7 @@ def _any(lst, d): @pipe(name='break') -def _break(req, *opts): +def _break(req: Plumbing.Request, *opts): """ Break out of a pipeline. @@ -305,7 +311,7 @@ def _break(req, *opts): @pipe(name='pipe') -def _pipe(req, *opts): +def _pipe(req: Plumbing.Request, *opts): """ Run the argument list as a pipleine. @@ -344,7 +350,10 @@ def _pipe(req, *opts): - two """ - ot = Plumbing(pipeline=req.args, pid="%s.pipe" % req.plumbing.id).iprocess(req) + if not isinstance(req.args, list): + raise ValueError('Non-list arguments to "pipe" not allowed') + + ot = Plumbing(pipeline=req.args, pid=f'{req.plumbing.id}.pipe').iprocess(req) req.done = False return ot @@ -378,12 +387,15 @@ def when(req: Plumbing.Request, condition: str, *values): if c is None: log.debug(f'Condition {repr(condition)} not present in state {req.state}') if c is not None and (not values or _any(values, c)): + if not isinstance(req.args, list): + raise ValueError('Non-list arguments to "when" not allowed') + return Plumbing(pipeline=req.args, pid="%s.when" % req.plumbing.id).iprocess(req) return req.t @pipe -def info(req, *opts): +def info(req: Plumbing.Request, *opts): """ Dumps the working document on stdout. Useful for testing. @@ -401,7 +413,7 @@ def info(req, *opts): @pipe -def sort(req, *opts): +def sort(req: Plumbing.Request, *opts): """ Sorts the working entities by the value returned by the given xpath. 
By default, entities are sorted by 'entityID' when the 'order_by [xpath]' option is omitted and @@ -424,15 +436,16 @@ def sort(req, *opts): if req.t is None: raise PipeException("Unable to sort empty document.") - opts = dict(list(zip(opts[0:1], [" ".join(opts[1:])]))) - opts.setdefault('order_by', None) - sort_entities(req.t, opts['order_by']) + _opts: Dict[str, Optional[str]] = dict(list(zip(opts[0:1], [" ".join(opts[1:])]))) + if 'order_by' not in _opts: + _opts['order_by'] = None + sort_entities(req.t, _opts['order_by']) return req.t @pipe -def publish(req, *opts): +def publish(req: Plumbing.Request, *opts): """ Publish the working document in XML form. @@ -473,12 +486,12 @@ def publish(req, *opts): if req.args is None: raise PipeException("Publish must at least specify output") - if type(req.args) is not dict: + if not isinstance(req.args, dict): req.args = dict(output=req.args[0]) for t in ('raw', 'update_store', 'hash_link', 'urlencode_filenames'): if t in req.args and type(req.args[t]) is not bool: - req.args[t] = strtobool("{}".format(req.args[t])) + req.args[t] = strtobool(str(req.args[t])) req.args.setdefault('ext', '.xml') req.args.setdefault('output_file', 'output') @@ -535,7 +548,7 @@ def _nop(x): @pipe @deprecated(reason="stats subsystem was removed") -def loadstats(req, *opts): +def loadstats(req: Plumbing.Request, *opts): """ Log (INFO) information about the result of the last call to load @@ -549,7 +562,7 @@ def loadstats(req, *opts): @pipe @deprecated(reason="replaced with load") -def remote(req, *opts): +def remote(req: Plumbing.Request, *opts): """ Deprecated. Calls :py:mod:`pyff.pipes.builtins.load`. """ @@ -558,7 +571,7 @@ def remote(req, *opts): @pipe @deprecated(reason="replaced with load") -def local(req, *opts): +def local(req: Plumbing.Request, *opts): """ Deprecated. Calls :py:mod:`pyff.pipes.builtins.load`. 
""" @@ -567,17 +580,17 @@ def local(req, *opts): @pipe @deprecated(reason="replaced with load") -def _fetch(req, *opts): +def _fetch(req: Plumbing.Request, *opts): return load(req, *opts) @pipe -def load(req, *opts): +def load(req: Plumbing.Request, *opts): """ General-purpose resource fetcher. :param req: The request - :param opts: Options: See "Options" below + :param _opts: Options: See "Options" below :return: None Supports both remote and local resources. Fetching remote resources is done in parallel using threads. @@ -607,20 +620,22 @@ def load(req, *opts): fail_on_error controls whether failure to validating the entire MD file will abort processing of the pipeline. """ - opts = dict(list(zip(opts[::2], opts[1::2]))) - opts.setdefault('timeout', 120) - opts.setdefault('max_workers', 5) - opts.setdefault('validate', "True") - opts.setdefault('fail_on_error', "False") - opts.setdefault('filter_invalid', "True") - opts['validate'] = bool(strtobool(opts['validate'])) - opts['fail_on_error'] = bool(strtobool(opts['fail_on_error'])) - opts['filter_invalid'] = bool(strtobool(opts['filter_invalid'])) - - remotes = [] + _opts = dict(list(zip(opts[::2], opts[1::2]))) + _opts.setdefault('timeout', 120) + _opts.setdefault('max_workers', 5) + _opts.setdefault('validate', "True") + _opts.setdefault('fail_on_error', "False") + _opts.setdefault('filter_invalid', "True") + _opts['validate'] = bool(strtobool(_opts['validate'])) + _opts['fail_on_error'] = bool(strtobool(_opts['fail_on_error'])) + _opts['filter_invalid'] = bool(strtobool(_opts['filter_invalid'])) + + if not isinstance(req.args, list): + raise ValueError('Non-list args to "load" not allowed') + for x in req.args: x = x.strip() - log.debug("load parsing '%s'" % x) + log.debug(f"load parsing '{x}'") r = x.split() assert len(r) in range(1, 8), PipeException( @@ -656,12 +671,12 @@ def load(req, *opts): child_opts.verify = elt # override anything in child_opts with what is in opts - child_opts = 
child_opts.copy(update=opts) + child_opts = child_opts.copy(update=_opts) req.md.rm.add_child(url, child_opts) log.debug("Refreshing all resources") - req.md.rm.reload(fail_on_error=bool(opts['fail_on_error'])) + req.md.rm.reload(fail_on_error=bool(_opts['fail_on_error'])) def _select_args(req): @@ -681,7 +696,7 @@ def _select_args(req): @pipe -def select(req, *opts): +def select(req: Plumbing.Request, *opts): """ Select a set of EntityDescriptor elements as the working document. @@ -814,7 +829,7 @@ def _match(q, elt): @pipe(name="filter") -def _filter(req, *opts): +def _filter(req: Plumbing.Request, *opts): """ Refines the working document by applying a filter. The filter expression is a subset of the @@ -864,7 +879,7 @@ def _filter(req, *opts): @pipe -def pick(req, *opts): +def pick(req: Plumbing.Request, *opts): """ Select a set of EntityDescriptor elements as a working document but don't validate it. @@ -884,7 +899,7 @@ def pick(req, *opts): @pipe -def first(req, *opts): +def first(req: Plumbing.Request, *opts): """ If the working document is a single EntityDescriptor, strip the outer EntitiesDescriptor element and return it. @@ -914,7 +929,7 @@ def first(req, *opts): @pipe(name='discojson') -def _discojson(req, *opts): +def _discojson(req: Plumbing.Request, *opts): """ Return a discojuice-compatible json representation of the tree @@ -941,7 +956,7 @@ def _discojson(req, *opts): @pipe -def sign(req, *opts): +def sign(req: Plumbing.Request, *_opts): """ Sign the working document. 
@@ -990,7 +1005,7 @@ def sign(req, *opts): if req.t is None: raise PipeException("Your pipeline is missing a select statement.") - if not type(req.args) is dict: + if not isinstance(req.args, dict): raise PipeException("Missing key and cert arguments to sign pipe") key_file = req.args.get('key', None) @@ -1006,14 +1021,14 @@ def sign(req, *opts): relt = root(req.t) idattr = relt.get('ID') if idattr: - opts['reference_uri'] = "#%s" % idattr + opts['reference_uri'] = f'#{idattr}' xmlsec.sign(req.t, key_file, cert_file, **opts) return req.t @pipe -def stats(req, *opts): +def stats(req: Plumbing.Request, *opts): """ Display statistics about the current working document. @@ -1050,7 +1065,7 @@ def stats(req, *opts): @pipe -def summary(req, *opts): +def summary(req: Plumbing.Request, *opts): """ Display a summary of the repository @@ -1066,7 +1081,7 @@ def summary(req, *opts): @pipe(name='store') -def _store(req, *opts): +def _store(req: Plumbing.Request, *opts): """ Save the working document as separate files @@ -1086,8 +1101,7 @@ def _store(req, *opts): if not req.args: raise PipeException("store requires an argument") - target_dir = None - if type(req.args) is dict: + if isinstance(req.args, dict): target_dir = req.args.get('directory', None) else: target_dir = req.args[0] @@ -1102,7 +1116,7 @@ def _store(req, *opts): @pipe -def xslt(req, *opts): +def xslt(req: Plumbing.Request, *opts): """ Transform the working document using an XSLT file. 
@@ -1128,6 +1142,9 @@ def xslt(req, *opts): if req.t is None: raise PipeException("Your plumbing is missing a select statement.") + if not isinstance(req.args, dict): + raise ValueError('Non-dict args to "xslt" not allowed') + stylesheet = req.args.get('stylesheet', None) if stylesheet is None: raise PipeException("xslt requires stylesheet") @@ -1142,7 +1159,7 @@ def xslt(req, *opts): @pipe -def validate(req, *opts): +def validate(req: Plumbing.Request, *opts): """ Validate the working document @@ -1163,7 +1180,7 @@ def validate(req, *opts): @pipe -def prune(req, *opts): +def prune(req: Plumbing.Request, *opts): """ Prune the active tree, removing all elements matching @@ -1193,6 +1210,9 @@ def prune(req, *opts): if req.t is None: raise PipeException("Your pipeline is missing a select statement.") + if not isinstance(req.args, list): + raise ValueError('Non-list args to "prune" not allowed') + for path in req.args: for part in req.t.iterfind(path): parent = part.getparent() @@ -1205,7 +1225,7 @@ def prune(req, *opts): @pipe -def check_xml_namespaces(req, *opts): +def check_xml_namespaces(req: Plumbing.Request, *opts): """ Ensure that all namespaces are http or httpd scheme URLs. @@ -1232,7 +1252,7 @@ def _verify(elt): @pipe -def drop_xsi_type(req, *opts): +def drop_xsi_type(req: Plumbing.Request, *opts): """ Remove all xsi namespaces from the tree. @@ -1255,7 +1275,7 @@ def _drop_xsi_type(elt): @pipe -def certreport(req, *opts): +def certreport(req: Plumbing.Request, *opts): """ Generate a report of the certificates (optionally limited by expiration time or key size) found in the selection. 
@@ -1289,7 +1309,7 @@ def certreport(req, *opts): if not req.args: req.args = {} - if type(req.args) is not dict: + if not isinstance(req.args, dict): raise PipeException("usage: certreport {warning: 864000, error: 0}") error_seconds = int(req.args.get('error_seconds', "0")) @@ -1297,7 +1317,7 @@ def certreport(req, *opts): error_bits = int(req.args.get('error_bits', "1024")) warning_bits = int(req.args.get('warning_bits', "2048")) - seen = {} + seen: Dict[str, bool] = {} for eid in req.t.xpath("//md:EntityDescriptor/@entityID", namespaces=NS, smart_strings=False): for cd in req.t.xpath( "md:EntityDescriptor[@entityID='%s']//ds:X509Certificate" % eid, namespaces=NS, smart_strings=False @@ -1308,9 +1328,9 @@ def certreport(req, *opts): m = hashlib.sha1() m.update(cert_der) fp = m.hexdigest() - if not seen.get(fp, False): - entity_elt = cd.getparent().getparent().getparent().getparent().getparent() + if fp not in seen: seen[fp] = True + entity_elt = cd.getparent().getparent().getparent().getparent().getparent() cdict = xmlsec.utils.b642cert(cert_pem) keysize = cdict['modulus'].bit_length() cert = cdict['cert'] @@ -1371,11 +1391,11 @@ def certreport(req, *opts): req.store.update(entity_elt) except Exception as ex: log.debug(traceback.format_exc()) - log.error(ex) + log.error(f'Got exception while creating certreport: {ex}') @pipe -def emit(req, ctype="application/xml", *opts): +def emit(req: Plumbing.Request, ctype="application/xml", *opts): """ Returns a UTF-8 encoded representation of the working tree. @@ -1426,7 +1446,7 @@ def emit(req, ctype="application/xml", *opts): @pipe -def signcerts(req, *opts): +def signcerts(req: Plumbing.Request, *opts): """ Logs the fingerprints of the signing certs found in the current working tree. @@ -1452,7 +1472,7 @@ def signcerts(req, *opts): @pipe -def finalize(req, *opts): +def finalize(req: Plumbing.Request, *opts): """ Prepares the working document for publication/rendering. 
@@ -1488,6 +1508,9 @@ def finalize(req, *opts): if req.t is None: raise PipeException("Your plumbing is missing a select statement.") + if not isinstance(req.args, dict): + raise ValueError('Non-dict args to "finalize" not allowed') + e = root(req.t) if e.tag == "{%s}EntitiesDescriptor" % NS['md']: name = req.args.get('name', None) @@ -1500,10 +1523,13 @@ def finalize(req, *opts): try: name_url = urlparse(name) base_url = urlparse(req.args.get('baseURL')) - name = "{}://{}{}".format(base_url.scheme, base_url.netloc, name_url.path) + # TODO: Investigate this error, which is probably correct: + # error: On Python 3 '{}'.format(b'abc') produces "b'abc'", not 'abc'; + # use '{!r}'.format(b'abc') if this is desired behavior + name = "{}://{}{}".format(base_url.scheme, base_url.netloc, name_url.path) # type: ignore log.debug("-------- using Name: %s" % name) except ValueError as ex: - log.debug(ex) + log.debug(f'Got an exception while finalizing: {ex}') name = None if name is None or 0 == len(name): name = e.get('Name', None) @@ -1541,7 +1567,9 @@ def finalize(req, *opts): # set a reasonable default: 50% of the validity # we replace this below if we have cacheDuration set # TODO: offset can be None here, if validUntil is not a valid duration or ISO date - req.state['cache'] = int(total_seconds(offset) / 50) + # What is the right action to take then? 
+ if offset: + req.state['cache'] = int(total_seconds(offset) / 50) cache_duration = req.args.get('cacheDuration', e.get('cacheDuration', None)) if cache_duration is not None and len(cache_duration) > 0: @@ -1556,7 +1584,7 @@ def finalize(req, *opts): @pipe(name='reginfo') -def _reginfo(req, *opts): +def _reginfo(req: Plumbing.Request, *opts): """ Sets registration info extension on EntityDescription element @@ -1580,6 +1608,9 @@ def _reginfo(req, *opts): if req.t is None: raise PipeException("Your pipeline is missing a select statement.") + if not isinstance(req.args, dict): + raise ValueError('Non-dict args to "reginfo" not allowed') + for e in iter_entities(req.t): set_reginfo(e, **req.args) @@ -1587,7 +1618,7 @@ def _reginfo(req, *opts): @pipe(name='pubinfo') -def _pubinfo(req, *opts): +def _pubinfo(req: Plumbing.Request, *opts): """ Sets publication info extension on EntityDescription element @@ -1609,13 +1640,16 @@ def _pubinfo(req, *opts): if req.t is None: raise PipeException("Your pipeline is missing a select statement.") + if not isinstance(req.args, dict): + raise ValueError('Non-dict args to "pubinfo" not allowed') + set_pubinfo(root(req.t), **req.args) return req.t @pipe(name='setattr') -def _setattr(req, *opts): +def _setattr(req: Plumbing.Request, *opts): """ Sets entity attributes on the working document @@ -1651,7 +1685,7 @@ def _setattr(req, *opts): @pipe(name='nodecountry') -def _nodecountry(req, *opts): +def _nodecountry(req: Plumbing.Request, *opts): """ Sets eidas:NodeCountry @@ -1675,6 +1709,9 @@ def _nodecountry(req, *opts): if req.t is None: raise PipeException("Your pipeline is missing a select statement.") + if not isinstance(req.args, dict): + raise ValueError('Non-dict args to "nodecountry" not allowed') + for e in iter_entities(req.t): if req.args is not None and 'country' in req.args: set_nodecountry(e, country_code=req.args['country']) diff --git a/src/pyff/logs.py b/src/pyff/logs.py index edee1dc7..6ab1419e 100644 --- 
a/src/pyff/logs.py +++ b/src/pyff/logs.py @@ -3,6 +3,7 @@ import logging import os import syslog +from typing import Any, Optional import six @@ -35,25 +36,25 @@ def _l(self, severity, msg): else: raise ValueError("unknown severity %s" % severity) - def warn(self, msg): + def warn(self, msg: str) -> Any: return self._l(logging.WARN, msg) - def warning(self, msg): + def warning(self, msg: str) -> Any: return self._l(logging.WARN, msg) - def info(self, msg): + def info(self, msg: str) -> Any: return self._l(logging.INFO, msg) - def error(self, msg): + def error(self, msg: str) -> Any: return self._l(logging.ERROR, msg) - def critical(self, msg): + def critical(self, msg: str) -> Any: return self._l(logging.CRITICAL, msg) - def debug(self, msg): + def debug(self, msg: str) -> Any: return self._l(logging.DEBUG, msg) - def isEnabledFor(self, lvl): + def isEnabledFor(self, lvl: Any) -> bool: return self._log.isEnabledFor(lvl) @@ -64,7 +65,7 @@ def get_log(name: str) -> PyFFLogger: log = get_log('pyff') -def log_config_file(ini): +def log_config_file(ini: Optional[str]) -> None: if ini is not None: import logging.config @@ -75,7 +76,7 @@ def log_config_file(ini): logging.config.fileConfig(ini) -log_config_file(os.getenv('PYFF_LOGGING', None)) +log_config_file(os.getenv('PYFF_LOGGING')) # http://www.aminus.org/blogs/index.php/2008/07/03/writing-high-efficiency-large-python-sys-1?blog=2 # blog post explicitly gives permission for use diff --git a/src/pyff/pipes.py b/src/pyff/pipes.py index 308ae83d..29abd427 100644 --- a/src/pyff/pipes.py +++ b/src/pyff/pipes.py @@ -7,11 +7,11 @@ import functools import os import traceback -from typing import Any, Callable, Dict, Iterable, Optional, Type +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union import yaml from apscheduler.schedulers.background import BackgroundScheduler -from lxml.etree import ElementTree +from lxml.etree import Element, ElementTree from pyff.logs import get_log from pyff.repo 
import MDRepository @@ -90,17 +90,17 @@ def the_something_func(req,*opts): # self[entry_point.name] = entry_point.load() -def load_pipe(d): +def load_pipe(d: Any) -> Tuple[Callable, Any, str, Optional[Union[str, Dict, List]]]: """Return a triple callable,name,args of the pipe specified by the object d. :param d: The following alternatives for d are allowed: - d is a string (or unicode) in which case the pipe is named d called with None as args. - d is a dict of the form {name: args} (i.e one key) in which case the pipe named *name* is called with args - - d is an iterable (eg tuple or list) in which case d[0] is treated as the pipe name and d[1:] becomes the args + - d is an iterable (a list) in which case d[0] is treated as the pipe name and d[1:] becomes the args """ - def _n(_d): + def _n(_d: str) -> Tuple[str, List[str]]: lst = _d.split() _name = lst[0] _opts = lst[1:] @@ -108,7 +108,7 @@ def _n(_d): name = None args = None - opts = [] + opts: List[str] = [] if is_text(d): name, opts = _n(d) elif hasattr(d, '__iter__') and not type(d) is dict: @@ -140,7 +140,7 @@ class PipelineCallback(object): A delayed pipeline callback used as a post for parse_saml_metadata """ - def __init__(self, entry_point, req, store=None): + def __init__(self, entry_point: str, req: Plumbing.Request, store: Optional[SAMLStoreBase] = None) -> None: self.entry_point = entry_point self.plumbing = Plumbing(req.scope_of(entry_point).plumbing.pipeline, f"{req.plumbing.id}-via-{entry_point}") self.req = req @@ -152,13 +152,15 @@ def __str__(self) -> str: def __repr__(self) -> str: return str(self) - def __copy__(self): + def __copy__(self) -> PipelineCallback: + # TODO: This seems... dangerous. What's the need for this? return self - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: Any) -> PipelineCallback: + # TODO: This seems... dangerous. What's the need for this? 
return self - def __call__(self, *args, **kwargs): + def __call__(self, *args: Any, **kwargs: Any) -> Any: log.debug("{!s}: called".format(self.plumbing)) t = args[0] if t is None: @@ -170,7 +172,7 @@ def __call__(self, *args, **kwargs): return self.plumbing.process(self.req.md, store=self.store, state=state, t=t) except Exception as ex: log.debug(traceback.format_exc()) - log.error(ex) + log.error(f'Got an exception executing the plumbing process: {ex}') raise ex @@ -222,7 +224,7 @@ def pid(self) -> str: def __iter__(self) -> Iterable[Dict[str, Any]]: return self.pipeline - def __str__(self): + def __str__(self) -> str: return "PL[id={!s}]".format(self.pid) class Request(object): @@ -250,9 +252,9 @@ def __init__( self.plumbing: Plumbing = pl self.md: MDRepository = md self.t: ElementTree = t - self._id = None + self._id: Optional[str] = None self.name = name - self.args: Iterable[Dict[str, Any]] = args + self.args: Optional[Union[str, Dict, List]] = args self.state: Dict[str, Any] = state self.done: bool = False self._store: SAMLStoreBase = store @@ -261,16 +263,16 @@ def __init__( self.exception: Optional[BaseException] = None self.parent: Optional[Plumbing.Request] = None - def scope_of(self, entry_point): - if 'with {}'.format(entry_point) in self.plumbing.pipeline: - return self - elif self.parent is None: + def scope_of(self, entry_point: str) -> Plumbing.Request: + for _p in self.plumbing.pipeline: + if f'with {entry_point}' in _p: + return self + if self.parent is None: return self - else: - return self.parent.scope_of(entry_point) + return self.parent.scope_of(entry_point) @property - def id(self): + def id(self) -> Optional[str]: if self.t is None: return None if self._id is None: @@ -279,10 +281,10 @@ def id(self): self._id = self.t.get('Name') return self._id - def set_id(self, _id): + def set_id(self, _id: Optional[str]) -> None: self._id = _id - def set_parent(self, _parent): + def set_parent(self, _parent: Optional[Plumbing.Request]) -> None: 
self.parent = _parent @property @@ -291,14 +293,14 @@ def store(self) -> SAMLStoreBase: return self._store return self.md.store - def process(self, pl: Plumbing): + def process(self, pl: Plumbing) -> ElementTree: """The inner request pipeline processor. :param pl: The plumbing to run this request through """ return pl.iprocess(self) - def iprocess(self, req: Plumbing.Request): + def iprocess(self, req: Plumbing.Request) -> ElementTree: """The inner request pipeline processor. :param req: The request to run through the pipeline @@ -325,7 +327,7 @@ def iprocess(self, req: Plumbing.Request): break except BaseException as ex: log.debug(traceback.format_exc()) - log.error(ex) + log.error(f'Got exception when loading/executing pipe: {ex}') req.exception = ex if req.raise_exceptions: raise ex @@ -335,13 +337,13 @@ def iprocess(self, req: Plumbing.Request): def process( self, md: MDRepository, - args=None, + args: Any = None, state: Optional[Dict[str, Any]] = None, - t=None, - store=None, + t: Optional[ElementTree] = None, + store: Optional[SAMLStoreBase] = None, raise_exceptions: bool = True, - scheduler=None, - ): + scheduler: Optional[BackgroundScheduler] = None, + ) -> Optional[Element]: # TODO: unsure about this return type """ The main entrypoint for processing a request pipeline. Calls the inner processor. 
diff --git a/src/pyff/repo.py b/src/pyff/repo.py index a608c6a7..aafe1a7c 100644 --- a/src/pyff/repo.py +++ b/src/pyff/repo.py @@ -13,7 +13,7 @@ class MDRepository: """A class representing a set of SAML metadata and the resources from where this metadata was loaded.""" - def __init__(self, scheduler=None): + def __init__(self, scheduler=None) -> None: random.seed(self) self.rm = Resource(url=None, opts=ResourceOpts()) # root if scheduler is None: diff --git a/src/pyff/resource.py b/src/pyff/resource.py index 83beb9b7..196c9683 100644 --- a/src/pyff/resource.py +++ b/src/pyff/resource.py @@ -198,7 +198,7 @@ class ResourceInfo(BaseModel): class Config: arbitrary_types_allowed = True - def to_dict(self): + def to_dict(self) -> Dict[str, Any]: def _format_key(k: str) -> str: special = {'http_headers': 'HTTP Response Headers'} if k in special: diff --git a/src/pyff/samlmd.py b/src/pyff/samlmd.py index e507168c..0b07a236 100644 --- a/src/pyff/samlmd.py +++ b/src/pyff/samlmd.py @@ -8,7 +8,7 @@ from lxml import etree from lxml.builder import ElementMaker -from lxml.etree import DocumentInvalid, ElementTree +from lxml.etree import DocumentInvalid, Element, ElementTree from pydantic import Field from xmlsec.crypto import CertDict @@ -740,7 +740,7 @@ def entity_extended_display(entity, langs=None): return display.strip(), info.strip() -def entity_display_name(entity, langs=None): +def entity_display_name(entity: Element, langs=None) -> str: """Utility-method for computing a displayable string for a given entity. 
:param entity: An EntityDescriptor element diff --git a/src/pyff/tools.py index aa556ade..dae12883 100644 --- a/src/pyff/tools.py +++ b/src/pyff/tools.py @@ -14,7 +14,7 @@ from xmldiff.main import diff_trees from pyff.constants import config, parse_options -from pyff.resource import Resource +from pyff.resource import Resource, ResourceOpts from pyff.samlmd import diff, iter_entities from pyff.store import MemoryStore @@ -31,12 +31,12 @@ def difftool(): try: rm = Resource() - r1 = Resource(args[0]) - r2 = Resource(args[1]) + r1 = Resource(url=args[0], opts=ResourceOpts()) + r2 = Resource(url=args[1], opts=ResourceOpts()) rm.add(r1) rm.add(r2) store = MemoryStore() rm.reload(store=store) status = 0 if r1.t.get('Name') != r2.t.get('Name'): diff --git a/src/pyff/utils.py index 9a6c5c1c..6c87ee7e 100644 --- a/src/pyff/utils.py +++ b/src/pyff/utils.py @@ -18,25 +18,25 @@ import threading import time import traceback +from collections.abc import Mapping, MutableMapping from copy import copy from datetime import datetime, timedelta, timezone from email.utils import parsedate from itertools import chain from threading import local from time import gmtime, strftime -from typing import BinaryIO, Callable, Optional, Union +from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union import pkg_resources import requests import xmlsec -from _collections_abc import Mapping, MutableMapping from apscheduler.executors.pool import ThreadPoolExecutor from apscheduler.jobstores.memory import MemoryJobStore from apscheduler.jobstores.redis import RedisJobStore from apscheduler.schedulers.background import BackgroundScheduler from cachetools import LRUCache from lxml import etree -from lxml.etree import ElementTree +from lxml.etree import Element, ElementTree from requests import Session from requests.adapters import BaseAdapter, HTTPAdapter, Response from requests.packages.urllib3.util.retry import Retry @@
-149,11 +149,11 @@ def totimestamp(dt: datetime, epoch=datetime(1970, 1, 1)) -> int: return int(ts) -def dumptree(t, pretty_print=False, method='xml', xml_declaration=True): +def dumptree(t: ElementTree, pretty_print: bool = False, method: str = 'xml', xml_declaration: bool = True) -> bytes: """ Return a string representation of the tree, optionally pretty_print(ed) (default False) - :param t: An ElemenTree to serialize + :param t: An ElementTree to serialize """ return etree.tostring( t, encoding='UTF-8', method=method, xml_declaration=xml_declaration, pretty_print=pretty_print @@ -390,25 +390,28 @@ def duration2timedelta(period: str) -> Optional[timedelta]: return delta -def _lang(elt, default_lang): +def _lang(elt: Element, default_lang: Optional[str]) -> Optional[str]: return elt.get("{http://www.w3.org/XML/1998/namespace}lang", default_lang) -def lang_dict(elts, getter=lambda e: e, default_lang=None): +def lang_dict(elts: Sequence[Element], getter=lambda e: e, default_lang: Optional[str] = None) -> Dict[str, Any]: if default_lang is None: default_lang = config.langs[0] r = dict() for e in elts: - r[_lang(e, default_lang)] = getter(e) + _l = _lang(e, default_lang) + if not _l: + raise ValueError('Could not get lang from element, and no default provided') + r[_l] = getter(e) return r -def find_lang(elts, lang, default_lang): +def find_lang(elts: Sequence[Element], lang: str, default_lang: str) -> Element: return next((e for e in elts if _lang(e, default_lang) == lang), elts[0]) -def filter_lang(elts, langs=None): +def filter_lang(elts: Any, langs: Optional[Sequence[str]] = None) -> List[Element]: if langs is None or type(langs) is not list: langs = config.langs @@ -422,6 +425,9 @@ def filter_lang(elts, langs=None): if len(elts) == 0: return [] + if not langs: + raise RuntimeError('Configuration is missing langs') + dflt = langs[0] lst = [find_lang(elts, l, dflt) for l in langs] if len(lst) > 0: @@ -486,7 +492,7 @@ def etag(s): return hex_digest(s,
hn="sha256") -def hash_id(entity, hn='sha1', prefix=True): +def hash_id(entity: Element, hn: str = 'sha1', prefix: bool = True) -> str: entity_id = entity if hasattr(entity, 'get'): entity_id = entity.get('entityID') @@ -657,7 +663,7 @@ def guess_entity_software(e): return 'other' -def is_text(x): +def is_text(x: Any) -> bool: return isinstance(x, six.string_types) or isinstance(x, six.text_type) @@ -773,7 +779,7 @@ def img_to_data(data: bytes, content_type: str) -> Optional[str]: assert data64 mime_type = "image/png" except BaseException as ex: - log.warning(ex) + log.warning(f'Exception when making Image: {ex}') log.debug(traceback.format_exc()) if data64 is None or len(data64) == 0: @@ -786,11 +792,11 @@ def short_id(data): return base64.urlsafe_b64encode(hasher.digest()[0:10]).rstrip('=') -def unicode_stream(data): +def unicode_stream(data: str) -> io.BytesIO: return six.BytesIO(data.encode('UTF-8')) -def b2u(data): +def b2u(data: Union[str, bytes, Tuple, List, Set]) -> Union[str, bytes, Tuple, List, Set]: if is_text(data): return data elif isinstance(data, six.binary_type):