|
1 |
| -from collections.abc import Collection, Mapping, Sequence |
| 1 | +import os |
| 2 | +from collections import defaultdict |
| 3 | +from collections.abc import Mapping, MutableMapping, MutableSequence |
| 4 | +from contextlib import contextmanager |
| 5 | +from copy import deepcopy |
| 6 | +from dataclasses import dataclass, field, replace |
| 7 | +from typing import Any, List, Optional, Sequence, Union |
2 | 8 |
|
3 |
| -# for testing purpose |
4 |
| -# FIXME: after implementing of reading of "params". |
5 |
| -TEST_DATA = { |
6 |
| - "__test__": { |
7 |
| - "dict": {"one": 1, "two": 2, "three": "three", "four": "4"}, |
8 |
| - "list": [1, 2, 3, 4, 3.14], |
9 |
| - "set": {1, 2, 3}, |
10 |
| - "tuple": (1, 2), |
11 |
| - "bool": True, |
12 |
| - "none": None, |
13 |
| - "float": 3.14, |
14 |
| - "nomnom": 1000, |
15 |
| - } |
16 |
| -} |
| 9 | +from funcy import identity |
17 | 10 |
|
| 11 | +from dvc.utils.serialize import LOADERS |
18 | 12 |
|
19 |
| -class Context: |
20 |
| - def __init__(self, data=None): |
21 |
| - self.data = data or TEST_DATA |
22 | 13 |
|
23 |
| - def select(self, key): |
24 |
| - return _get_value(self.data, key) |
| 14 | +def _merge(into, update, overwrite): |
| 15 | + for key, val in update.items(): |
| 16 | + if isinstance(into.get(key), Mapping) and isinstance(val, Mapping): |
| 17 | + _merge(into[key], val, overwrite) |
| 18 | + else: |
| 19 | + if key in into and not overwrite: |
| 20 | + raise ValueError( |
| 21 | + f"Cannot overwrite as key {key} already exists in {into}" |
| 22 | + ) |
| 23 | + into[key] = val |
25 | 24 |
|
26 | 25 |
|
27 |
| -def _get_item(data, idx): |
28 |
| - if isinstance(data, Sequence): |
29 |
| - idx = int(idx) |
| 26 | +@dataclass |
| 27 | +class Meta: |
| 28 | + source: Optional[str] = None |
| 29 | + dpaths: List[str] = field(default_factory=list) |
30 | 30 |
|
31 |
| - if isinstance(data, (Mapping, Sequence)): |
32 |
| - return data[idx] |
| 31 | + @staticmethod |
| 32 | + def update_path(meta: "Meta", path: Union[str, int]): |
| 33 | + dpaths = meta.dpaths[:] + [str(path)] |
| 34 | + return replace(meta, dpaths=dpaths) |
33 | 35 |
|
34 |
| - raise ValueError( |
35 |
| - f"Cannot get item '{idx}' from data of type '{type(data).__name__}'" |
| 36 | + def __str__(self): |
| 37 | + string = self.source or "<local>:" |
| 38 | + string += self.path() |
| 39 | + return string |
| 40 | + |
| 41 | + def path(self): |
| 42 | + return ".".join(self.dpaths) |
| 43 | + |
| 44 | + |
| 45 | +def _default_meta(): |
| 46 | + return Meta(source=None) |
| 47 | + |
| 48 | + |
| 49 | +@dataclass |
| 50 | +class Value: |
| 51 | + value: Any |
| 52 | + meta: Meta = field( |
| 53 | + compare=False, default_factory=_default_meta, repr=False |
36 | 54 | )
|
37 | 55 |
|
| 56 | + def __repr__(self): |
| 57 | + return repr(self.value) |
38 | 58 |
|
39 |
| -def _get_value(data, key): |
40 |
| - obj_and_attrs = key.strip().split(".") |
41 |
| - value = data |
42 |
| - for attr in obj_and_attrs: |
43 |
| - if attr == "": |
44 |
| - raise ValueError("Syntax error!") |
| 59 | + def get_sources(self): |
| 60 | + return {self.meta.source: self.meta.path()} |
45 | 61 |
|
46 |
| - try: |
47 |
| - value = _get_item(value, attr) |
48 |
| - except KeyError: |
| 62 | + |
| 63 | +class Container: |
| 64 | + meta: Meta |
| 65 | + data: Union[list, dict] |
| 66 | + _key_transform = staticmethod(identity) |
| 67 | + |
| 68 | + def __init__(self, meta) -> None: |
| 69 | + self.meta = meta or Meta(source=None) |
| 70 | + |
| 71 | + def _convert(self, key, value): |
| 72 | + meta = Meta.update_path(self.meta, key) |
| 73 | + if value is None or isinstance(value, (int, float, str, bytes, bool)): |
| 74 | + return Value(value, meta=meta) |
| 75 | + elif isinstance(value, (CtxList, CtxDict, Value)): |
| 76 | + return value |
| 77 | + elif isinstance(value, (list, dict)): |
| 78 | + container = CtxDict if isinstance(value, dict) else CtxList |
| 79 | + return container(value, meta=meta) |
| 80 | + else: |
49 | 81 | msg = (
|
50 |
| - f"Could not find '{attr}' " |
51 |
| - "while substituting " |
52 |
| - f"'{key}'.\n" |
53 |
| - f"Interpolating with: {data}" |
| 82 | + "Unsupported value of type " |
| 83 | + f"'{type(value).__name__}' in '{meta}'" |
54 | 84 | )
|
55 |
| - raise ValueError(msg) |
| 85 | + raise TypeError(msg) |
| 86 | + |
| 87 | + def __repr__(self): |
| 88 | + return repr(self.data) |
| 89 | + |
| 90 | + def __getitem__(self, key): |
| 91 | + return self.data[key] |
| 92 | + |
| 93 | + def __setitem__(self, key, value): |
| 94 | + self.data[key] = self._convert(key, value) |
| 95 | + |
| 96 | + def __delitem__(self, key): |
| 97 | + del self.data[key] |
| 98 | + |
| 99 | + def __len__(self): |
| 100 | + return len(self.data) |
| 101 | + |
| 102 | + def __iter__(self): |
| 103 | + return iter(self.data) |
| 104 | + |
| 105 | + def __eq__(self, o): |
| 106 | + return o.data == self.data |
| 107 | + |
| 108 | + def select(self, key: str): |
| 109 | + index, *rems = key.split(sep=".", maxsplit=1) |
| 110 | + index = index.strip() |
| 111 | + index = self._key_transform(index) |
| 112 | + try: |
| 113 | + d = self.data[index] |
| 114 | + except LookupError as exc: |
| 115 | + raise ValueError( |
| 116 | + f"Could not find '{index}' in {self.data}" |
| 117 | + ) from exc |
| 118 | + return d.select(rems[0]) if rems else d |
| 119 | + |
| 120 | + def get_sources(self): |
| 121 | + return {} |
| 122 | + |
| 123 | + |
| 124 | +class CtxList(Container, MutableSequence): |
| 125 | + _key_transform = staticmethod(int) |
| 126 | + |
| 127 | + def __init__(self, values: Sequence, meta: Meta = None): |
| 128 | + super().__init__(meta=meta) |
| 129 | + self.data: list = [] |
| 130 | + self.extend(values) |
| 131 | + |
| 132 | + def insert(self, index: int, value): |
| 133 | + self.data.insert(index, self._convert(index, value)) |
| 134 | + |
| 135 | + def get_sources(self): |
| 136 | + return {self.meta.source: self.meta.path()} |
| 137 | + |
| 138 | + |
| 139 | +class CtxDict(Container, MutableMapping): |
| 140 | + def __init__(self, mapping: Mapping = None, meta: Meta = None, **kwargs): |
| 141 | + super().__init__(meta=meta) |
| 142 | + |
| 143 | + self.data: dict = {} |
| 144 | + if mapping: |
| 145 | + self.update(mapping) |
| 146 | + self.update(kwargs) |
| 147 | + |
| 148 | + def __setitem__(self, key, value): |
| 149 | + if not isinstance(key, str): |
| 150 | + # limitation for the interpolation |
| 151 | + # ignore other kinds of keys |
| 152 | + return |
| 153 | + return super().__setitem__(key, value) |
| 154 | + |
| 155 | + def merge_update(self, *args, overwrite=False): |
| 156 | + for d in args: |
| 157 | + _merge(self.data, d, overwrite=overwrite) |
| 158 | + |
| 159 | + |
| 160 | +class Context(CtxDict): |
| 161 | + def __init__(self, *args, **kwargs): |
| 162 | + """ |
| 163 | + Top level mutable dict, with some helpers to create context and track |
| 164 | + """ |
| 165 | + super().__init__(*args, **kwargs) |
| 166 | + self._track = False |
| 167 | + self._tracked_data = defaultdict(set) |
| 168 | + |
| 169 | + @contextmanager |
| 170 | + def track(self): |
| 171 | + self._track = True |
| 172 | + yield |
| 173 | + self._track = False |
| 174 | + |
| 175 | + def _track_data(self, node): |
| 176 | + if not self._track: |
| 177 | + return |
| 178 | + |
| 179 | + for source, keys in node.get_sources().items(): |
| 180 | + if not source: |
| 181 | + continue |
| 182 | + params_file = self._tracked_data[source] |
| 183 | + keys = [keys] if isinstance(keys, str) else keys |
| 184 | + params_file.update(keys) |
| 185 | + |
| 186 | + @property |
| 187 | + def tracked(self): |
| 188 | + return [ |
| 189 | + {file: list(keys)} for file, keys in self._tracked_data.items() |
| 190 | + ] |
| 191 | + |
| 192 | + def select(self, key: str): |
| 193 | + node = super().select(key) |
| 194 | + self._track_data(node) |
| 195 | + return node |
| 196 | + |
| 197 | + @classmethod |
| 198 | + def load_from(cls, tree, file: str) -> "Context": |
| 199 | + _, ext = os.path.splitext(file) |
| 200 | + loader = LOADERS[ext] |
| 201 | + |
| 202 | + meta = Meta(source=file) |
| 203 | + return cls(loader(file, tree=tree), meta=meta) |
56 | 204 |
|
57 |
| - if not isinstance(value, str) and isinstance(value, Collection): |
58 |
| - raise ValueError( |
59 |
| - f"Cannot interpolate value of type '{type(value).__name__}'" |
60 |
| - ) |
61 |
| - return value |
| 205 | + @classmethod |
| 206 | + def clone(cls, ctx: "Context") -> "Context": |
| 207 | + """Clones given context.""" |
| 208 | + return cls(deepcopy(ctx.data)) |
0 commit comments