import itertools
from typing import Sequence, Mapping
from comfy.graph import DynamicPrompt

import nodes

from comfy.graph_utils import is_link

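For context on the input encoding this file assumes (the node id and values below are invented for illustration): a literal input is stored directly in a node's `inputs` dict, while a connection to another node's output is stored as a two-element list of upstream node id and output socket index, which is what `is_link` detects.

```python
# Hypothetical prompt fragment, for illustration only.
example_node = {
    "class_type": "KSampler",
    "inputs": {
        "seed": 42,          # literal value
        "model": ["4", 0],   # link: output socket 0 of node "4"
    },
}
```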
# Base class that maps node ids to cache keys. Subclasses decide what a key
# looks like; this class only tracks which keys and subcache keys are in use.
class CacheKeySet:
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        self.keys = {}
        self.subcache_keys = {}

    def add_keys(self, node_ids):
        raise NotImplementedError()

    def all_node_ids(self):
        return set(self.keys.keys())

    def get_used_keys(self):
        return self.keys.values()

    def get_used_subcache_keys(self):
        return self.subcache_keys.values()

    def get_data_key(self, node_id):
        return self.keys.get(node_id, None)

    def get_subcache_key(self, node_id):
        return self.subcache_keys.get(node_id, None)

class Unhashable:
    def __init__(self):
        self.value = float("NaN")

def to_hashable(obj):
    # Check primitive types first so we don't recurse infinitely
    # (str, for example, is itself a Sequence).
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, Mapping):
        return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
    elif isinstance(obj, Sequence):
        return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
    else:
        # TODO - Support other objects like tensors?
        return Unhashable()
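A rough illustration of `to_hashable` (the input values are invented): mappings become frozensets of key/value pairs, sequences become frozensets of (index, item) pairs so that ordering still matters, and anything unrecognized is wrapped in an `Unhashable`, whose NaN value never compares equal, so such signatures effectively never produce a cache hit.

```python
# Sketch of the conversions above; the input values are invented.
key = to_hashable({"steps": 20, "size": [512, 512]})
# -> frozenset({("steps", 20),
#               ("size", frozenset({(0, 512), (1, 512)}))})

# The enumerated pairs preserve ordering information:
assert to_hashable([1, 2]) != to_hashable([2, 1])
```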

# Keys each node by (node_id, class_type), so a cached value is only reused by
# the exact same node across executions.
class CacheKeySetID(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.add_keys(node_ids)

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = (node_id, node["class_type"])
            self.subcache_keys[node_id] = (node_id, node["class_type"])

# Keys each node by a signature built from its class type, its IS_CHANGED
# result, and its inputs, plus the signatures of every ancestor it depends on.
class CacheKeySetInputSignature(CacheKeySet):
    def __init__(self, dynprompt, node_ids, is_changed_cache):
        super().__init__(dynprompt, node_ids, is_changed_cache)
        self.dynprompt = dynprompt
        self.is_changed_cache = is_changed_cache
        self.add_keys(node_ids)

    def include_node_id_in_input(self) -> bool:
        return False

    def add_keys(self, node_ids):
        for node_id in node_ids:
            if node_id in self.keys:
                continue
            node = self.dynprompt.get_node(node_id)
            self.keys[node_id] = self.get_node_signature(self.dynprompt, node_id)
            self.subcache_keys[node_id] = (node_id, node["class_type"])

    def get_node_signature(self, dynprompt, node_id):
        signature = []
        ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
        signature.append(self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
        for ancestor_id in ancestors:
            signature.append(self.get_immediate_node_signature(dynprompt, ancestor_id, order_mapping))
        return to_hashable(signature)

    def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
        node = dynprompt.get_node(node_id)
        class_type = node["class_type"]
        class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
        signature = [class_type, self.is_changed_cache.get(node_id)]
        if self.include_node_id_in_input() or (hasattr(class_def, "NOT_IDEMPOTENT") and class_def.NOT_IDEMPOTENT):
            signature.append(node_id)
        inputs = node["inputs"]
        for key in sorted(inputs.keys()):
            if is_link(inputs[key]):
                (ancestor_id, ancestor_socket) = inputs[key]
                ancestor_index = ancestor_order_mapping[ancestor_id]
                signature.append((key, ("ANCESTOR", ancestor_index, ancestor_socket)))
            else:
                signature.append((key, inputs[key]))
        return signature

    # Returns a list of all ancestors of the given node. The order of the list
    # is deterministic, based on which inputs each ancestor is connected to.
    def get_ordered_ancestry(self, dynprompt, node_id):
        ancestors = []
        order_mapping = {}
        self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
        return ancestors, order_mapping

    def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
        inputs = dynprompt.get_node(node_id)["inputs"]
        input_keys = sorted(inputs.keys())
        for key in input_keys:
            if is_link(inputs[key]):
                ancestor_id = inputs[key][0]
                if ancestor_id not in order_mapping:
                    ancestors.append(ancestor_id)
                    order_mapping[ancestor_id] = len(ancestors) - 1
                    self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)

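To make the signature construction concrete, here is a hedged sketch (node ids, class types, and inputs are invented) of what `get_node_signature` assembles for a node `"2"` whose `image` input is linked to output 0 of node `"1"`. Links are rewritten to positional `("ANCESTOR", index, socket)` references, so two structurally identical graphs hash to the same key even if their node ids differ.

```python
# Hypothetical prompt:
#   "1": {"class_type": "LoadImage",  "inputs": {"image": "example.png"}}
#   "2": {"class_type": "ImageScale", "inputs": {"image": ["1", 0], "width": 512}}
#
# get_ordered_ancestry(dynprompt, "2") -> (["1"], {"1": 0})
#
# Before to_hashable(), the signature for node "2" looks roughly like:
# [
#     ["ImageScale", <is_changed for "2">, ("image", ("ANCESTOR", 0, 0)), ("width", 512)],
#     ["LoadImage",  <is_changed for "1">, ("image", "example.png")],
# ]
```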
# Shared storage used by the cache implementations below: values keyed by the
# active CacheKeySet, plus nested subcaches for expanded subgraphs.
class BasicCache:
    def __init__(self, key_class):
        self.key_class = key_class
        self.initialized = False
        self.dynprompt: DynamicPrompt
        self.cache_key_set: CacheKeySet
        self.cache = {}
        self.subcaches = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        self.dynprompt = dynprompt
        self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
        self.is_changed_cache = is_changed_cache
        self.initialized = True

    def all_node_ids(self):
        assert self.initialized
        node_ids = self.cache_key_set.all_node_ids()
        for subcache in self.subcaches.values():
            node_ids = node_ids.union(subcache.all_node_ids())
        return node_ids

    def _clean_cache(self):
        preserve_keys = set(self.cache_key_set.get_used_keys())
        to_remove = []
        for key in self.cache:
            if key not in preserve_keys:
                to_remove.append(key)
        for key in to_remove:
            del self.cache[key]

    def _clean_subcaches(self):
        preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())

        to_remove = []
        for key in self.subcaches:
            if key not in preserve_subcaches:
                to_remove.append(key)
        for key in to_remove:
            del self.subcaches[key]

    def clean_unused(self):
        assert self.initialized
        self._clean_cache()
        self._clean_subcaches()

    def _set_immediate(self, node_id, value):
        assert self.initialized
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.cache[cache_key] = value

    def _get_immediate(self, node_id):
        if not self.initialized:
            return None
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key in self.cache:
            return self.cache[cache_key]
        else:
            return None

    def _ensure_subcache(self, node_id, children_ids):
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        subcache = self.subcaches.get(subcache_key, None)
        if subcache is None:
            subcache = BasicCache(self.key_class)
            self.subcaches[subcache_key] = subcache
        subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
        return subcache

    def _get_subcache(self, node_id):
        assert self.initialized
        subcache_key = self.cache_key_set.get_subcache_key(node_id)
        if subcache_key in self.subcaches:
            return self.subcaches[subcache_key]
        else:
            return None

    def recursive_debug_dump(self):
        result = []
        for key in self.cache:
            result.append({"key": key, "value": self.cache[key]})
        for key in self.subcaches:
            result.append({"subcache_key": key, "subcache": self.subcaches[key].recursive_debug_dump()})
        return result

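A minimal usage sketch of `BasicCache` on its own, assuming `DynamicPrompt` can be built directly from a prompt dict as it is elsewhere in the execution code; the ID-based key set never reads `is_changed_cache`, so it is left as `None` here. Note that `BasicCache` only exposes the immediate variants; the classes below add the public `get`/`set` interface.

```python
# Hedged sketch; prompt contents are invented.
prompt = {
    "1": {"class_type": "LoadImage", "inputs": {"image": "example.png"}},
    "2": {"class_type": "ImageScale", "inputs": {"image": ["1", 0], "width": 512, "height": 512}},
}
cache = BasicCache(CacheKeySetID)
cache.set_prompt(DynamicPrompt(prompt), ["1", "2"], is_changed_cache=None)
cache._set_immediate("1", "node 1 output placeholder")
cache._set_immediate("2", "node 2 output placeholder")

# Re-pointing the cache at a prompt that no longer uses node "2" and calling
# clean_unused() drops every key the new key set does not reference.
cache.set_prompt(DynamicPrompt({"1": prompt["1"]}), ["1"], is_changed_cache=None)
cache.clean_unused()
assert cache._get_immediate("2") is None
assert cache._get_immediate("1") == "node 1 output placeholder"
```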
# Cache that mirrors the prompt hierarchy: values for nodes inside an expanded
# subgraph live in the subcache belonging to their parent node.
class HierarchicalCache(BasicCache):
    def __init__(self, key_class):
        super().__init__(key_class)

    def _get_cache_for(self, node_id):
        assert self.dynprompt is not None
        parent_id = self.dynprompt.get_parent_node_id(node_id)
        if parent_id is None:
            return self

        hierarchy = []
        while parent_id is not None:
            hierarchy.append(parent_id)
            parent_id = self.dynprompt.get_parent_node_id(parent_id)

        cache = self
        for parent_id in reversed(hierarchy):
            cache = cache._get_subcache(parent_id)
            if cache is None:
                return None
        return cache

    def get(self, node_id):
        cache = self._get_cache_for(node_id)
        if cache is None:
            return None
        return cache._get_immediate(node_id)

    def set(self, node_id, value):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        cache._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        cache = self._get_cache_for(node_id)
        assert cache is not None
        return cache._ensure_subcache(node_id, children_ids)

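A hedged sketch of what `HierarchicalCache` adds on top of that. The node ids are invented, and `dynprompt`, `is_changed_cache`, `parent_output`, and `child_output` are assumed to come from the executor, with the child nodes already registered in `dynprompt` under parent node `"5"`: values for nodes inside an expanded subgraph are stored in the subcache created for their parent, and `get`/`set` walk `get_parent_node_id` to find the right level first.

```python
# Hedged sketch; "5.0" and "5.1" stand in for nodes expanded out of node "5".
cache = HierarchicalCache(CacheKeySetID)
cache.set_prompt(dynprompt, ["1", "5"], is_changed_cache)

cache.set("5", parent_output)                    # top-level entry
cache.ensure_subcache_for("5", ["5.0", "5.1"])   # subcache keyed by ("5", class_type)
cache.set("5.0", child_output)                   # stored inside node "5"'s subcache
assert cache.get("5.0") is child_output          # found via get_parent_node_id("5.0") -> "5"
```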
# Flat cache that keeps up to max_size entries alive across prompt executions,
# evicting the least recently used generations first.
class LRUCache(BasicCache):
    def __init__(self, key_class, max_size=100):
        super().__init__(key_class)
        self.max_size = max_size
        self.min_generation = 0
        self.generation = 0
        self.used_generation = {}
        self.children = {}

    def set_prompt(self, dynprompt, node_ids, is_changed_cache):
        super().set_prompt(dynprompt, node_ids, is_changed_cache)
        self.generation += 1
        for node_id in node_ids:
            self._mark_used(node_id)

    def clean_unused(self):
        while len(self.cache) > self.max_size and self.min_generation < self.generation:
            self.min_generation += 1
            to_remove = [key for key in self.cache if self.used_generation[key] < self.min_generation]
            for key in to_remove:
                del self.cache[key]
                del self.used_generation[key]
                if key in self.children:
                    del self.children[key]
        self._clean_subcaches()

    def get(self, node_id):
        self._mark_used(node_id)
        return self._get_immediate(node_id)

    def _mark_used(self, node_id):
        cache_key = self.cache_key_set.get_data_key(node_id)
        if cache_key is not None:
            self.used_generation[cache_key] = self.generation

    def set(self, node_id, value):
        self._mark_used(node_id)
        return self._set_immediate(node_id, value)

    def ensure_subcache_for(self, node_id, children_ids):
        # Just uses subcaches for tracking 'live' nodes
        super()._ensure_subcache(node_id, children_ids)

        self.cache_key_set.add_keys(children_ids)
        self._mark_used(node_id)
        cache_key = self.cache_key_set.get_data_key(node_id)
        self.children[cache_key] = []
        for child_id in children_ids:
            self._mark_used(child_id)
            self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
        return self

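Finally, a hedged sketch of the LRU behaviour (the prompts, outputs, and `is_changed_cache` are placeholders, and the node signatures are assumed to be unchanged between the two prompts): every `set_prompt` bumps `generation`, `get`/`set` stamp the touched keys with the current generation, and `clean_unused` only evicts once the cache is over `max_size`, removing the stalest generations first. This is what lets results outlive a prompt that no longer references them.

```python
# Hedged sketch; dynprompt_a / dynprompt_b and is_changed_cache are assumed to
# come from two consecutive prompt executions.
lru = LRUCache(CacheKeySetInputSignature, max_size=2)

lru.set_prompt(dynprompt_a, ["1", "2", "3"], is_changed_cache)   # generation 1
lru.set("1", out1)
lru.set("2", out2)
lru.set("3", out3)
lru.clean_unused()   # 3 entries > max_size, but none are older than the current generation

lru.set_prompt(dynprompt_b, ["1", "2"], is_changed_cache)        # generation 2 marks "1" and "2" as used
lru.clean_unused()   # the entry for "3" (last used in generation 1) is evicted
```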