From 0e9aaf1981ee9324276d90531529e747b93b3473 Mon Sep 17 00:00:00 2001
From: Wuwei Lin
Date: Sat, 25 May 2024 21:48:54 -0700
Subject: [PATCH] [Contrib] Implement NDArray cache update

---
 python/tvm/contrib/tvmjs.py                | 77 ++++++++++++++++++++--
 tests/python/relax/test_runtime_builtin.py | 26 +++++++
 2 files changed, 96 insertions(+), 7 deletions(-)

diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py
index 923301a1f509..2a7604c0ada2 100644
--- a/python/tvm/contrib/tvmjs.py
+++ b/python/tvm/contrib/tvmjs.py
@@ -24,7 +24,7 @@
 # pylint: disable=unused-import
 import sys
 from types import GeneratorType
-from typing import Iterator, Mapping, Tuple, Union
+from typing import Any, Dict, Iterator, List, Mapping, Optional, Set, Tuple, Union
 
 import numpy as np
 
@@ -73,7 +73,13 @@ def _calculate_md5(filename):
 class NDArrayCacheShardingManager:
     """Internal helper to shard ndarrays."""
 
-    def __init__(self, cache_dir: str, prefix: str, shard_cap_nbytes: int):
+    def __init__(
+        self,
+        cache_dir: str,
+        prefix: str,
+        shard_cap_nbytes: int,
+        initial_shard_records: Optional[List[Dict[str, Any]]] = None,
+    ):
         self.cache_dir = cache_dir
         self.prefix = prefix
         self.curr_records = []
@@ -81,8 +87,17 @@ def __init__(self, cache_dir: str, prefix: str, shard_cap_nbytes: int):
         self.shard_records = []
         self.shard_cap_nbytes = shard_cap_nbytes
         self.counter = 0
+        self.name_to_record: Dict[str, Tuple[int, Dict[str, Any]]] = {}
+        self.updated_shards: Set[int] = set()
 
-    def append(self, data, name, shape, dtype, encode_format):
+        if initial_shard_records is not None:
+            self.shard_records = initial_shard_records
+            self.counter = len(initial_shard_records)
+            for idx, shard in enumerate(initial_shard_records):
+                for rec in shard["records"]:
+                    self.name_to_record[rec["name"]] = (idx, rec)
+
+    def append_or_update(self, data, name, shape, dtype, encode_format, allow_update: bool = False):
         """Commit a record to the manager.
 
         Parameters
         ----------
         data
             The data of the entry
@@ -101,6 +116,9 @@ def append(self, data, name, shape, dtype, encode_format):
 
         encode_format:
             The encode format of the entry
+
+        allow_update: bool
+            If True, update the existing record in place; otherwise a duplicate name raises an error.
""" rec = { "name": name, @@ -109,6 +127,13 @@ def append(self, data, name, shape, dtype, encode_format): "format": encode_format, "nbytes": len(data), } + if name in self.name_to_record: + if not allow_update: + raise ValueError(f"Duplicate name {name} found in the cache.") + self.update_single_record(rec, data) + return + + self.name_to_record[name] = (self.counter, rec) if self.pending_nbytes + len(data) >= self.shard_cap_nbytes: if len(data) * 2 >= self.shard_cap_nbytes: @@ -121,6 +146,20 @@ def append(self, data, name, shape, dtype, encode_format): self.curr_records.append(rec) self.curr_data += data + def update_single_record(self, rec, data): + """Update a single record in a shard file.""" + name = rec["name"] + idx, old_rec = self.name_to_record[name] + if old_rec["nbytes"] != rec["nbytes"]: + raise ValueError(f"Cannot update record {name}, size mismatch.") + data_path = self.shard_records[idx]["dataPath"] + full_path = os.path.join(self.cache_dir, data_path) + with open(full_path, "r+b") as outfile: + outfile.seek(old_rec["byteOffset"]) + outfile.write(data) + self.name_to_record[name] = (idx, rec) + self.updated_shards.add(idx) + def commit(self): """Commit a record""" if self.pending_nbytes != 0: @@ -131,6 +170,9 @@ def commit(self): def finish(self): """Finish building and return shard records.""" self.commit() + for idx in self.updated_shards: + full_path = os.path.join(self.cache_dir, self.shard_records[idx]["dataPath"]) + self.shard_records[idx]["md5sum"] = _calculate_md5(full_path) return self.shard_records def _commit_internal(self, data, records): @@ -165,6 +207,7 @@ def dump_ndarray_cache( meta_data=None, shard_cap_mb=32, show_progress: bool = True, + update_if_exists: bool = False, ): """Dump parameters to NDArray cache. @@ -191,6 +234,10 @@ def dump_ndarray_cache( show_progress: bool A boolean indicating if to show the dump progress. + + update_if_exists: bool + If the cache already exists, update the cache. When set to False, it will overwrite the + existing files. 
""" if encode_format not in ("raw", "f32-to-bf16"): raise ValueError(f"Invalie encode_format {encode_format}") @@ -209,7 +256,17 @@ def dump_ndarray_cache( print("Start storing to cache %s" % cache_dir) shard_cap_nbytes = shard_cap_mb * (1 << 20) - shard_manager = NDArrayCacheShardingManager(cache_dir, "params_shard", shard_cap_nbytes) + nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json") + if update_if_exists and os.path.exists(nd_cache_json): + with open(nd_cache_json, "r") as infile: + old_data = json.load(infile) + if meta_data is None: + meta_data = old_data["metadata"] + records = old_data["records"] + + shard_manager = NDArrayCacheShardingManager( + cache_dir, "params_shard", shard_cap_nbytes, initial_shard_records=records + ) param_generator = params.items() if not from_generator else params for k, origin_v in param_generator: @@ -229,7 +286,14 @@ def dump_ndarray_cache( else: data = v.tobytes() - shard_manager.append(data, name=k, shape=shape, dtype=dtype, encode_format=encode_format) + shard_manager.append_or_update( + data, + name=k, + shape=shape, + dtype=dtype, + encode_format=encode_format, + allow_update=update_if_exists, + ) counter += 1 if show_progress: @@ -241,8 +305,6 @@ def dump_ndarray_cache( records = shard_manager.finish() meta_data = {} if meta_data is None else meta_data if not callable(meta_data) else meta_data() - nd_cache_json = os.path.join(cache_dir, "ndarray-cache.json") - with open(nd_cache_json, "w") as outfile: json.dump({"metadata": meta_data, "records": records}, outfile, indent=4) print( diff --git a/tests/python/relax/test_runtime_builtin.py b/tests/python/relax/test_runtime_builtin.py index 614d32ce0c7d..fb4c8abdf9e6 100644 --- a/tests/python/relax/test_runtime_builtin.py +++ b/tests/python/relax/test_runtime_builtin.py @@ -188,6 +188,31 @@ def test_ndarray_cache(): np.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6) +def test_ndarray_cache_update(): + fload = tvm.get_global_func("vm.builtin.ndarray_cache.load") + fget_params = tvm.get_global_func("vm.builtin.param_array_from_cache") + + param_dict = { + "x_0": np.array([1, 2, 3], dtype="int32"), + "x_1": np.random.uniform(size=[10, 20]).astype("float32"), + } + + temp = utils.tempdir() + tvmjs.dump_ndarray_cache(param_dict, temp.path, encode_format="f32-to-bf16") + param_dict["x_1"] = np.random.uniform(size=[10, 20]).astype("float32") + param_dict["x_2"] = np.random.uniform(size=[10]).astype("float32") + tvmjs.dump_ndarray_cache( + param_dict, temp.path, encode_format="f32-to-bf16", update_if_exists=True + ) + fload(str(temp.path), tvm.cpu().device_type, 0) + res = fget_params("x", -1) + for i, v in enumerate(res): + v_np = param_dict[f"x_{i}"] + if v_np.dtype == "float32": + v_np = tvmjs._convert_bf16_to_f32(tvmjs._convert_f32_to_bf16(v_np)) + np.testing.assert_allclose(v.numpy(), v_np, atol=1e-6, rtol=1e-6) + + def test_attention_kv_cache_window_override(): fcreate = tvm.get_global_func("vm.builtin.attention_kv_cache_create") foverride = tvm.get_global_func("vm.builtin.attention_kv_cache_window_override")