diff --git a/python/tvm/contrib/tvmjs.py b/python/tvm/contrib/tvmjs.py index 2a7604c0ada2..9bff724df7bc 100644 --- a/python/tvm/contrib/tvmjs.py +++ b/python/tvm/contrib/tvmjs.py @@ -35,6 +35,7 @@ import tvm from tvm._ffi.libinfo import find_lib_path +from tvm.runtime import DataType from .emcc import create_tvmjs_wasm @@ -276,7 +277,13 @@ def dump_ndarray_cache( v = v.numpy() # prefer to preserve original dtype, especially if the format was bfloat16 - dtype = str(origin_v.dtype) if isinstance(origin_v, tvm.nd.NDArray) else str(v.dtype) + dtype = origin_v.dtype if isinstance(origin_v, tvm.nd.NDArray) else v.dtype + + if dtype in DataType.NUMPY2STR: + dtype = DataType.NUMPY2STR[dtype] + else: + dtype = str(dtype) + total_bytes += math.prod(v.shape) * np.dtype(v.dtype).itemsize # convert fp32 to bf16 diff --git a/tests/python/contrib/test_tvmjs.py b/tests/python/contrib/test_tvmjs.py new file mode 100644 index 000000000000..22742ec224ef --- /dev/null +++ b/tests/python/contrib/test_tvmjs.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Test contrib.tvmjs""" + +import tempfile + +import numpy as np +import pytest + +import tvm.testing +from tvm.contrib import tvmjs + +dtype = tvm.testing.parameter( + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", + "float16", + "float32", + "float64", + "float8_e4m3fn", + "float8_e5m2", +) + + +def test_save_load_float8(dtype): + if "float8" in dtype or "bfloat16" in dtype: + ml_dtypes = pytest.importorskip("ml_dtypes") + np_dtype = np.dtype(getattr(ml_dtypes, dtype)) + else: + np_dtype = np.dtype(dtype) + + arr = np.arange(16, dtype=np_dtype) + + with tempfile.TemporaryDirectory(prefix="tvm_") as temp_dir: + tvmjs.dump_ndarray_cache({"arr": arr}, temp_dir) + cache, _ = tvmjs.load_ndarray_cache(temp_dir, tvm.cpu()) + + after_roundtrip = cache["arr"].numpy() + + np.testing.assert_array_equal(arr, after_roundtrip) + + +if __name__ == "__main__": + tvm.testing.main()