diff --git a/fsspec/implementations/reference.py b/fsspec/implementations/reference.py
index 2da6a4fb0..2ac555b9c 100644
--- a/fsspec/implementations/reference.py
+++ b/fsspec/implementations/reference.py
@@ -1,3 +1,5 @@
+import base64
+import itertools
 import json
 
 from ..asyn import AsyncFileSystem
@@ -70,10 +72,11 @@ def __init__(
             raise NotImplementedError("Only works with async targets")
         if isinstance(references, str):
             with open(references, "rb", **(ref_storage_args or {})) as f:
-                references = json.load(f)
-        self.references = references
+                text = f.read()
+        else:
+            text = references
         self.target = target
-        self._process_references()
+        self._process_references(text)
         self.fs = fs
 
     async def _cat_file(self, path):
@@ -89,9 +92,85 @@ async def _cat_file(self, path):
             url = self.target
         return await self.fs._cat_file(url, start=start, end=end)
 
-    def _process_references(self):
-        if "zarr_consolidated_format" in self.references:
-            self.references = _unmodel_hdf5(self.references)
+    def _process_references(self, references):
+        """Load references (JSON bytes or dict) per their spec version.
+
+        Raises ValueError for an unknown "version" value.
+        """
+        if isinstance(references, bytes):
+            references = json.loads(references.decode())
+        vers = references.get("version", None)
+        if vers is None:
+            self._process_references0(references)
+        elif vers == 1:
+            self._process_references1(references)
+        else:
+            raise ValueError(f"Unknown reference spec version: {vers}")
+        # TODO: we make dircache by iterating over all entries, but for Spec >= 1,
+        # can replace with programmatic. Is it even needed for mapper interface?
+        self._dircache_from_items()
+
+    def _process_references0(self, references):
+        """Make reference dict for Spec Version 0"""
+        if "zarr_consolidated_format" in references:
+            # special case for Ike prototype
+            references = _unmodel_hdf5(references)
+        self.references = references
+
+    def _process_references1(self, references):
+        """Make reference dict for Spec Version 1
+
+        Expands "templates", literal "refs" and generated "gen" entries
+        into ``self.references``. Requires jinja2.
+        """
+        try:
+            import jinja2
+        except ImportError as e:
+            raise ValueError("Reference Spec Version 1 requires jinja2") from e
+        self.references = {}
+        templates = {}
+        for k, v in references.get("templates", {}).items():
+            if "{{" in v:
+                # callable template: rendered with arguments given in the ref URL
+                templates[k] = lambda temp=v, **kwargs: jinja2.Template(temp).render(
+                    **kwargs
+                )
+            else:
+                templates[k] = v
+
+        for k, v in references["refs"].items():
+            if isinstance(v, str):
+                if v.startswith("base64:"):
+                    # inline binary data: keep the decoded bytes
+                    self.references[k] = base64.b64decode(v[7:])
+                else:
+                    self.references[k] = v
+            else:
+                u, off, length = v
+                if "{{" in u:
+                    u = jinja2.Template(u).render(**templates)
+                self.references[k] = [u, off, length]
+        for gen in references.get("gen", []):
+            # each dimension is an explicit list of values or a range spec
+            dimension = {
+                k: v
+                if isinstance(v, list)
+                else range(v.get("start", 0), v["stop"], v.get("step", 1))
+                for k, v in gen["dimensions"].items()
+            }
+            products = (
+                dict(zip(dimension.keys(), values))
+                for values in itertools.product(*dimension.values())
+            )
+            for pr in products:
+                key = jinja2.Template(gen["key"]).render(**pr, **templates)
+                url = jinja2.Template(gen["url"]).render(**pr, **templates)
+                offset = int(jinja2.Template(gen["offset"]).render(**pr, **templates))
+                length = int(jinja2.Template(gen["length"]).render(**pr, **templates))
+
+                self.references[key] = [url, offset, length]
+
+    def _dircache_from_items(self):
         self.dircache = {"": []}
         for path, part in self.references.items():
             if isinstance(part, (bytes, str)):
diff --git a/fsspec/implementations/tests/test_reference.py b/fsspec/implementations/tests/test_reference.py
index 0379d5f2d..4f8052906 100644
--- a/fsspec/implementations/tests/test_reference.py
+++ b/fsspec/implementations/tests/test_reference.py
@@ -91,3 +91,40 @@ def test_unmodel():
     refs = _unmodel_hdf5(json.loads(jdata))
     assert b'"Conventions": "UGRID-0.9.0"' in refs[".zattrs"]
     assert refs["adcirc_mesh/0"] == ("https://url", 8928, 8932)
+
+
+def test_spec1_expand():
+    pytest.importorskip("jinja2")
+    in_data = {
+        "version": 1,
+        "templates": {"u": "server.domain/path", "f": "{{c}}"},
+        "gen": [
+            {
+                "key": "gen_key{{i}}",
+                "url": "http://{{u}}_{{i}}",
+                "offset": "{{(i + 1) * 1000}}",
+                "length": "1000",
+                "dimensions": {"i": {"stop": 5}},
+            }
+        ],
+        "refs": {
+            "key0": "data",
+            "key1": ["http://target_url", 10000, 100],
+            "key2": ["http://{{u}}", 10000, 100],
+            "key3": ["http://{{f(c='text')}}", 10000, 100],
+            "key4": "base64:aGVsbG8=",
+        },
+    }
+    fs = fsspec.filesystem("reference", references=in_data, target_protocol="http")
+    assert fs.references == {
+        "key0": "data",
+        "key1": ["http://target_url", 10000, 100],
+        "key2": ["http://server.domain/path", 10000, 100],
+        "key3": ["http://text", 10000, 100],
+        "key4": b"hello",
+        "gen_key0": ["http://server.domain/path_0", 1000, 1000],
+        "gen_key1": ["http://server.domain/path_1", 2000, 1000],
+        "gen_key2": ["http://server.domain/path_2", 3000, 1000],
+        "gen_key3": ["http://server.domain/path_3", 4000, 1000],
+        "gen_key4": ["http://server.domain/path_4", 5000, 1000],
+    }
diff --git a/tox.ini b/tox.ini
index dbae9ebd1..dbe777b2c 100644
--- a/tox.ini
+++ b/tox.ini
@@ -33,6 +33,7 @@ conda_deps=
     python-libarchive-c
     numpy
     nomkl
+    jinja2
 deps=
     hadoop-test-cluster==0.1.0
     smbprotocol